Oleh convert (#200)

* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor more corrections to support convertors

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading

* StringUtils for utf convertor #8613 tests added some corrections to optimize build

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 some corrections and code clean up

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 minor corrections and bug fix before discussion

* StringUtils for utf convertor #8613 some fixes and tets

* StringUtils for utf convertor #8613 some more staff to support different unicode

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 fix linking bug

* StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed

* StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing

* StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization)

* StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 merge master and sync with server

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up

* StringUtils for utf convertor #8613 fixed cast and copy issues

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 fixed cuda and update tests

* StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc

* - avoid ambiguity of NDArray ctrs overloading in some tests

Signed-off-by: Yurii <iuriish@yahoo.com>

* StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613  void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 several more fixes, improvements and updates

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 master merge, code clean up and optimization before review

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 minor fixes string element size define

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 revert last changes as mistake

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>
master
Oleh 2020-01-31 15:30:49 +02:00 committed by GitHub
parent 00cd61f32d
commit d52e67209e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 2992 additions and 464 deletions

View File

@ -195,6 +195,56 @@ namespace nd4j {
NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
/**
* This contructors create scalar array containing string utf8
*
*/
NDArray(const char* str, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
: NDArray(std::string(str), dtype, context) {
}
NDArray(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
/**
* This contructors create scalar array containing string utf16
*
*/
NDArray(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
: NDArray(std::u16string(u16string), dtype, context) {
}
NDArray(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
/**
* This contructors create scalar array containing string utf32
*
*/
NDArray(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
: NDArray(std::u32string(u32string), dtype, context) {
}
NDArray(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
/**
* This contructors create array from vector of utf8 strings
*
*/
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::string>& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
/**
* This contructors create array from vector of utf16 strings
*
*/
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
/**
* This contructors create array from vector of utf32 strings
*
*/
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
#endif #endif
/** /**
@ -250,7 +300,6 @@ namespace nd4j {
*/ */
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false); NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
/** /**
* This method returns new array with the same shape & data type * This method returns new array with the same shape & data type
* @return * @return
@ -1148,6 +1197,9 @@ namespace nd4j {
template <typename N> template <typename N>
NDArray asT() const; NDArray asT() const;
template <typename S>
NDArray asS() const;
NDArray asT(DataType dtype) const; NDArray asT(DataType dtype) const;

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
// //
// Created by raver119 on 2018-09-16. // Created by raver119 on 2018-09-16.
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
// //
#ifndef DEV_TESTS_NDARRAYFACTORY_H #ifndef DEV_TESTS_NDARRAYFACTORY_H
@ -106,25 +108,72 @@ namespace nd4j {
template <typename T> template <typename T>
static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); /**
* This factory create array from utf8 string
* @return NDArray default dataType UTF8
*/
static NDArray string(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(const std::string &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); /**
* This factory create array from utf16 string
* @return NDArray default dataType UTF16
*/
static NDArray string(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); /**
* This factory create array from utf32 string
* @return NDArray default dataType UTF32
*/
static NDArray string(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); /**
* This factory create array from vector of utf8 strings
* @return NDArray default dataType UTF8
*/
static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); /**
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); * This factory create array from vector of utf16 strings
* @return NDArray default dataType UTF16
*/
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); /**
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); * This factory create array from vector of utf32 strings
* @return NDArray default dataType UTF32
*/
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
// //
// Created by GS <sgazeos@gmail.com> on 2018-12-20. // Created by GS <sgazeos@gmail.com> on 2018-12-20.
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
// //
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
@ -25,6 +27,9 @@
#include <ShapeUtils.h> #include <ShapeUtils.h>
#include <type_traits> #include <type_traits>
#include <StringUtils.h>
namespace nd4j { namespace nd4j {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
@ -85,45 +90,6 @@ namespace nd4j {
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t>& data, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t>& data, nd4j::LaunchContext * context);
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool>& data, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool>& data, nd4j::LaunchContext * context);
NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
std::string s(str);
return string(s, context);
}
NDArray* NDArrayFactory::string_(const char *str, nd4j::LaunchContext * context) {
return string_(std::string(str), context);
}
NDArray NDArrayFactory::string(const std::string &str, nd4j::LaunchContext * context) {
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1);
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(headerLength + str.length(), DataType::UTF8, context->getWorkspace(), true);
NDArray res(pBuffer, ShapeDescriptor::scalarDescriptor(DataType::UTF8), context);
int8_t* buffer = reinterpret_cast<int8_t*>(res.getBuffer());
auto offsets = reinterpret_cast<Nd4jLong *>(buffer);
offsets[0] = 0;
offsets[1] = str.length();
auto data = buffer + headerLength;
memcpy(data, str.c_str(), str.length());
res.tickWriteHost();
res.syncToDevice();
return res;
}
NDArray* NDArrayFactory::string_(const std::string &str, nd4j::LaunchContext * context) {
auto res = new NDArray();
*res = NDArrayFactory::string(str, context);
return res;
}
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context) { NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context) {
@ -551,91 +517,175 @@ template ND4J_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char
template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context); template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
/////////////////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(u16string, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return string_(std::u16string(u16string), dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
auto res = new NDArray();
*res = NDArray(u16string, dtype, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(u16string, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(u32string, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return string_(std::u32string(u32string), dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
auto res = new NDArray();
*res = NDArray(u32string, dtype, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(u32string, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(str, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return string_(std::string(str), dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
auto res = new NDArray();
*res = NDArray(str, dtype, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray(str, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArray(shape, std::vector<const char*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArray( shape, strings, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArray( shape, std::vector<std::string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArrayFactory::string_( shape, std::vector<const char*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
std::vector<std::string> vec(strings.size());
int cnt = 0;
for (auto s:strings)
vec[cnt++] = std::string(s);
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context) { return NDArrayFactory::string_( shape, vec, dataType, context);
std::vector<const char*> vec(strings); }
return NDArrayFactory::string(order, shape, vec, context); /////////////////////////////////////////////////////////////////////////
} NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
return NDArrayFactory::string_( shape, std::vector<std::string>(string), dataType, context);
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context) { }
std::vector<std::string> vec(strings.size()); /////////////////////////////////////////////////////////////////////////
int cnt = 0; NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
for (auto s:strings) return NDArray(shape, string, dataType, context);
vec[cnt++] = std::string(s); }
/////////////////////////////////////////////////////////////////////////
return NDArrayFactory::string(order, shape, vec, context); NDArray* NDArrayFactory::string_(const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
} auto res = new NDArray();
*res = NDArray( shape, string, dataType, context);
return res;
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context) { }
std::vector<std::string> vec(string); /////////////////////////////////////////////////////////////////////////
return NDArrayFactory::string(order, shape, vec, context); NDArray NDArrayFactory::string(const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
} return NDArray( shape, std::vector<const char16_t*>(strings), dataType, context);
}
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context) { /////////////////////////////////////////////////////////////////////////
std::vector<const char*> vec(strings); NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArrayFactory::string_(order, shape, vec, context); return NDArray( shape, strings, dataType, context);
} }
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context) { NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
std::vector<std::string> vec(strings.size()); return NDArray( shape, std::vector<std::u16string>(string), dataType, context);
int cnt = 0; }
for (auto s:strings) /////////////////////////////////////////////////////////////////////////
vec[cnt++] = std::string(s); NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArrayFactory::string_( shape, std::vector<const char16_t*>(strings), dataType, context);
return NDArrayFactory::string_(order, shape, vec, context); }
} /////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
std::vector<std::u16string> vec(strings.size());
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context) { int cnt = 0;
std::vector<std::string> vec(string); for (auto s : strings)
return NDArrayFactory::string_(order, shape, vec, context); vec[cnt++] = std::u16string(s);
}
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context) {
if (context == nullptr)
context = nd4j::LaunchContext ::defaultContext();
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size());
std::vector<Nd4jLong> offsets(string.size() + 1);
Nd4jLong dataLength = 0;
for (int e = 0; e < string.size(); e++) {
offsets[e] = dataLength;
dataLength += string[e].length();
}
offsets[string.size()] = dataLength;
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(headerLength + dataLength, DataType::UTF8, context->getWorkspace(), true);
NDArray res(pBuffer, ShapeDescriptor(DataType::UTF8, order, shape), context);
res.setAttached(context->getWorkspace() != nullptr);
if (res.lengthOf() != string.size())
throw std::invalid_argument("Number of strings should match length of array");
memcpy(res.buffer(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
auto data = static_cast<int8_t*>(res.buffer()) + headerLength;
int resLen = res.lengthOf();
for (int e = 0; e < resLen; e++) {
auto length = offsets[e+1] - offsets[e];
auto cdata = data + offsets[e];
memcpy(cdata, string[e].c_str(), string[e].length());
}
res.tickWriteHost();
res.syncToDevice();
return res;
}
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context) {
auto res = new NDArray();
*res = NDArrayFactory::string(order, shape, string, context);
return res;
}
return NDArrayFactory::string_( shape, vec, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArrayFactory::string_( shape, std::vector<std::u16string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
auto res = new NDArray();
*res = NDArray( shape, string, dataType, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray( shape, string, dtype, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArray( shape, std::vector<const char32_t*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArray( shape, strings, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArray(shape, std::vector<std::u32string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArrayFactory::string_( shape, std::vector<const char32_t*>(strings), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
std::vector<std::u32string> vec(strings.size());
int cnt = 0;
for (auto s : strings)
vec[cnt++] = std::u32string(s);
return NDArrayFactory::string_( shape, vec, dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
return NDArrayFactory::string_( shape, std::vector<std::u32string>(string), dataType, context);
}
/////////////////////////////////////////////////////////////////////////
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
auto res = new NDArray();
*res = NDArray( shape, string, dataType, context);
return res;
}
/////////////////////////////////////////////////////////////////////////
NDArray NDArrayFactory::string(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
return NDArray( shape, string, dtype, context);
}
} }

View File

@ -122,7 +122,7 @@ namespace nd4j {
} }
FORCEINLINE bool DataTypeUtils::isS(nd4j::DataType dataType) { FORCEINLINE bool DataTypeUtils::isS(nd4j::DataType dataType) {
return dataType == nd4j::DataType::UTF8; return dataType == nd4j::DataType::UTF8 || dataType == nd4j::DataType::UTF16 || dataType == nd4j::DataType::UTF32;
} }
FORCEINLINE bool DataTypeUtils::isZ(nd4j::DataType dataType) { FORCEINLINE bool DataTypeUtils::isZ(nd4j::DataType dataType) {
@ -370,6 +370,10 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
return std::string("UINT64"); return std::string("UINT64");
case UTF8: case UTF8:
return std::string("UTF8"); return std::string("UTF8");
case UTF16:
return std::string("UTF16");
case UTF32:
return std::string("UTF32");
default: default:
throw std::runtime_error("Unknown data type used"); throw std::runtime_error("Unknown data type used");
} }
@ -431,6 +435,8 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
case nd4j::DataType::UINT16: return (size_t) 2; case nd4j::DataType::UINT16: return (size_t) 2;
case nd4j::DataType::UTF8: case nd4j::DataType::UTF8:
case nd4j::DataType::UTF16:
case nd4j::DataType::UTF32:
case nd4j::DataType::INT32: case nd4j::DataType::INT32:
case nd4j::DataType::UINT32: case nd4j::DataType::UINT32:
case nd4j::DataType::HALF2: case nd4j::DataType::HALF2:
@ -455,6 +461,10 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
return nd4j::DataType::BOOL; return nd4j::DataType::BOOL;
} else if (std::is_same<T, std::string>::value) { } else if (std::is_same<T, std::string>::value) {
return nd4j::DataType::UTF8; return nd4j::DataType::UTF8;
} else if (std::is_same<T, std::u16string>::value) {
return nd4j::DataType::UTF16;
} else if (std::is_same<T, std::u32string>::value) {
return nd4j::DataType::UTF32;
} else if (std::is_same<T, float>::value) { } else if (std::is_same<T, float>::value) {
return nd4j::DataType::FLOAT32; return nd4j::DataType::FLOAT32;
} else if (std::is_same<T, float16>::value) { } else if (std::is_same<T, float16>::value) {

View File

@ -49,11 +49,10 @@ namespace nd4j {
delete[] newShape; delete[] newShape;
return NDArrayFactory::empty_(dtype, nullptr); return NDArrayFactory::empty_(dtype, nullptr);
} }
// TODO fix UTF16 and UTF32
if (dtype == UTF8) { if (dtype == UTF8) {
bool isBe = BitwiseUtils::isBE(); bool isBe = BitwiseUtils::isBE();
bool canKeep = (isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_BE) || (!isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_LE); bool canKeep = (isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_BE) || (!isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_LE);
auto order = shape::order(newShape);
std::vector<std::string> substrings(length); std::vector<std::string> substrings(length);
std::vector<Nd4jLong> shapeVector(rank); std::vector<Nd4jLong> shapeVector(rank);
@ -88,8 +87,8 @@ namespace nd4j {
delete[] offsets; delete[] offsets;
delete[] newShape; delete[] newShape;
// string order always 'c'
return NDArrayFactory::string_(order, shapeVector, substrings); return NDArrayFactory::string_(shapeVector, substrings);
} }

View File

@ -171,7 +171,10 @@ namespace nd4j {
* @param numStrings * @param numStrings
* @return * @return
*/ */
static Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings); static FORCEINLINE Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings) {
// we store +1 offset
return (numStrings + 1) * sizeof(Nd4jLong);
}
/* /*
* check whether arr1/arr2 is sub-array of arr2/arr1, * check whether arr1/arr2 is sub-array of arr2/arr1,

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
// //
// Created by raver119 on 20/04/18. // Created by raver119 on 20/04/18.
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
// //
#ifndef LIBND4J_STRINGUTILS_H #ifndef LIBND4J_STRINGUTILS_H
@ -27,6 +29,7 @@
#include <sstream> #include <sstream>
#include <vector> #include <vector>
#include <NDArray.h> #include <NDArray.h>
#include <unicode.h>
namespace nd4j { namespace nd4j {
class ND4J_EXPORT StringUtils { class ND4J_EXPORT StringUtils {
@ -85,6 +88,55 @@ namespace nd4j {
* @return * @return
*/ */
static std::vector<std::string> split(const std::string &haystack, const std::string &delimiter); static std::vector<std::string> split(const std::string &haystack, const std::string &delimiter);
/**
* This method convert u8 string to u16
* @param const reference to input string
* @param reference to output u16string
* @return boolean status
*/
static bool u8StringToU16String(const std::string& u8, std::u16string& u16);
/**
* This method convert u8 string to u32
* @param const reference to input string
* @param reference to output u32string
* @return boolean status
*/
static bool u8StringToU32String(const std::string& u8, std::u32string& u32);
/**
* This method convert u16 string to u32
* @param const reference to input u16string
* @param reference to output u32string
* @return boolean status
*/
static bool u16StringToU32String(const std::u16string& u16, std::u32string& u32);
/**
* This method convert u16 string to u8 string
* @param const reference to input u16string
* @param reference to output string
* @return boolean status
*/
static bool u16StringToU8String(const std::u16string& u16, std::string& u8);
/**
* This method convert u32 string to u16 string
* @param const reference to input u32string
* @param reference to output u16string
* @return boolean status
*/
static bool u32StringToU16String(const std::u32string& u32, std::u16string& u16);
/**
* This method convert u32 string to u8 string
* @param const reference to input u32string
* @param reference to output string
* @return boolean status
*/
static bool u32StringToU8String(const std::u32string& u32, std::string& u8);
}; };
} }

View File

@ -1019,15 +1019,6 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
return numOfMinTads == 1 ? maxTadDims : std::vector<int>(); return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
} }
Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {
// we store +1 offset
auto base = numStrings + 1;
// since we return number of bytes...
return base * sizeof(Nd4jLong);
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
/* /*
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) { bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
// //
// Created by raver119 on 20/04/18. // Created by raver119 on 20/04/18.
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
// //
#include <helpers/StringUtils.h> #include <helpers/StringUtils.h>
@ -49,13 +51,8 @@ namespace nd4j {
if (!array.isS()) if (!array.isS())
throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType()); throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType());
uint64_t result = 0;
// our buffer stores offsets, and the last value is basically number of bytes used
auto buffer = array.bufferAsT<Nd4jLong>(); auto buffer = array.bufferAsT<Nd4jLong>();
result = buffer[array.lengthOf()]; return buffer[array.lengthOf()];
return result;
} }
std::vector<std::string> StringUtils::split(const std::string &haystack, const std::string &delimiter) { std::vector<std::string> StringUtils::split(const std::string &haystack, const std::string &delimiter) {
@ -73,4 +70,89 @@ namespace nd4j {
return output; return output;
} }
bool StringUtils::u8StringToU16String(const std::string& u8, std::u16string& u16) {
if (u8.empty())
return false;
u16.resize(unicode::offsetUtf8StringInUtf16(u8.data(), u8.size()) / sizeof(char16_t));
if (u8.size() == u16.size())
u16.assign(u8.begin(), u8.end());
else
return unicode::utf8to16(u8.data(), &u16[0], u8.size());
return true;
}
bool StringUtils::u8StringToU32String(const std::string& u8, std::u32string& u32) {
if (u8.empty())
return false;
u32.resize( unicode::offsetUtf8StringInUtf32(u8.data(), u8.size()) / sizeof(char32_t) );
if (u8.size() == u32.size())
u32.assign(u8.begin(), u8.end());
else
return unicode::utf8to32(u8.data(), &u32[0], u8.size());
return true;
}
bool StringUtils::u16StringToU32String(const std::u16string& u16, std::u32string& u32) {
if (u16.empty())
return false;
u32.resize(unicode::offsetUtf16StringInUtf32(u16.data(), u16.size()) / sizeof(char32_t));
if (u16.size() == u32.size())
u32.assign(u16.begin(), u16.end());
else
return unicode::utf16to32(u16.data(), &u32[0], u16.size());
return true;
}
bool StringUtils::u16StringToU8String(const std::u16string& u16, std::string& u8) {
if (u16.empty())
return false;
u8.resize(unicode::offsetUtf16StringInUtf8(u16.data(), u16.size()));
if (u16.size() == u8.size())
u8.assign(u16.begin(), u16.end());
else
return unicode::utf16to8(u16.data(), &u8[0], u16.size());
return true;
}
bool StringUtils::u32StringToU16String(const std::u32string& u32, std::u16string& u16) {
if (u32.empty())
return false;
u16.resize(unicode::offsetUtf32StringInUtf16(u32.data(), u32.size()) / sizeof(char16_t));
if (u32.size() == u16.size())
u16.assign(u32.begin(), u32.end());
else
return unicode::utf32to16(u32.data(), &u16[0], u32.size());
return true;
}
bool StringUtils::u32StringToU8String(const std::u32string& u32, std::string& u8) {
if (u32.empty())
return false;
u8.resize(unicode::offsetUtf32StringInUtf8(u32.data(), u32.size()));
if (u32.size() == u8.size())
u8.assign(u32.begin(), u32.end());
else
return unicode::utf32to8(u32.data(), &u8[0], u32.size());
return true;
}
} }

View File

@ -0,0 +1,456 @@
/*******************************************************************************
* Copyright (c) 2015-2020 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//
#include <unicode.h>
namespace nd4j {
namespace unicode {
constexpr uint32_t ONEBYTEBOUND = 0x00000080;
constexpr uint32_t TWOBYTEBOUND = 0x00000800;
constexpr uint32_t THREEBYTEBOUND = 0x00010000;
constexpr uint16_t HIGHBYTEMIN = 0xd800u;
constexpr uint16_t HIGHBYTEMAX = 0xdbffu;
constexpr uint16_t TRAILBYTEMIN = 0xdc00u;
constexpr uint16_t TRAILBYTEMAX = 0xdfffu;
constexpr uint16_t HIGHBYTEOFFSET = HIGHBYTEMIN - (0x10000 >> 10);
constexpr uint32_t BYTEOFFSET = 0x10000u - (HIGHBYTEMIN << 10) - TRAILBYTEMIN;
// Maximum valid value for a Unicode code point
constexpr uint32_t CODEPOINTMAX = 0x0010ffffu;
template<typename T>
FORCEINLINE uint8_t castToU8(const T cp) {
return static_cast<uint8_t>(0xff & cp);
}
template<typename T>
FORCEINLINE uint16_t castToU16(const T cp) {
return static_cast<uint16_t>(0xffff & cp);
}
template<typename T>
FORCEINLINE uint32_t castToU32(const T cp) {
return static_cast<uint32_t>(0xffffff & cp);
}
template<typename T>
FORCEINLINE bool isTrail(const T cp) {
return ((castToU8(cp) >> 6) == 0x2);
}
template <typename T>
FORCEINLINE bool isHighSurrogate(const T cp) {
return (cp & 0xfffffc00) == 0xd800;
}
template <typename T>
bool isLowSurrogate(const T cp) {
return (cp & 0xfffffc00) == 0xdc00;
}
template <typename T>
FORCEINLINE bool isLeadSurrogate(const T cp) {
return (cp >= HIGHBYTEMIN && cp <= HIGHBYTEMAX);
}
template <typename T>
FORCEINLINE bool isTrailSurrogate(const T cp) {
return (cp >= TRAILBYTEMIN && cp <= TRAILBYTEMAX);
}
template <typename T>
FORCEINLINE bool isSurrogateU8(const T cp) {
return (cp >= HIGHBYTEMIN && cp <= TRAILBYTEMAX);
}
template <typename T>
FORCEINLINE bool isSurrogateU16(const T cp) {
return ((cp - 0xd800u) < 2048u);
}
template <typename T>
FORCEINLINE bool isSymbolU8Valid(const T cp) {
return (cp <= CODEPOINTMAX && !isSurrogateU8(cp));
}
template <typename T>
FORCEINLINE bool isSymbolValid(const T cp) {
return (cp <= CODEPOINTMAX);
}
template <typename T>
FORCEINLINE uint32_t surrogateU32(const T& high, const T& low) {
return (high << 10) + low - 0x35fdc00;
}
template <typename T>
Nd4jLong symbolLength(const T* it) {
uint8_t lead = castToU8(*it);
if (lead < 0x80)
return 1;
else if ((lead >> 5) == 0x6)
return 2;
else if ((lead >> 4) == 0xe)
return 3;
else if ((lead >> 3) == 0x1e)
return 4;
else
return 0;
}
template <typename T>
Nd4jLong symbolLength32(const T* it) {
auto lead = castToU32(*it);
if (lead < ONEBYTEBOUND)
return 1;
else if (lead < TWOBYTEBOUND)
return 2;
else if (lead < THREEBYTEBOUND)
return 3;
else if (lead <= CODEPOINTMAX)
return 4;
else
return 0;
}
template <typename T>
Nd4jLong symbolLength16(const T* it) {
uint32_t lead = castToU16(*it);
if (!isLeadSurrogate(lead)) {
if (lead < ONEBYTEBOUND)
return 1;
else if (lead < TWOBYTEBOUND)
return 2;
else if (lead < THREEBYTEBOUND)
return 3;
else
return 0;
}
else {
return 4;
}
}
Nd4jLong offsetUtf8StringInUtf32(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const int8_t*>(start); it != end; it++) {
auto length = symbolLength(it);
it += (length > 0) ? (length - 1) : 0;
count += 1;
}
return static_cast<Nd4jLong>(count * sizeof(char32_t));
}
Nd4jLong offsetUtf16StringInUtf32(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const uint16_t*>(start); it != end;) {
auto length = symbolLength16(it);
it += (4 == length) ? 2 : 1;
count += 1;
}
return static_cast<Nd4jLong>(count*sizeof(char32_t));
}
Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const int8_t*>(start); it != end; it++) {
auto length = symbolLength(it);
auto step = ((length > 0) ? (length - 1) : 0);
it += step;
count += (4 == length) ? 2 : 1;
}
return static_cast<Nd4jLong>(count*sizeof(char16_t));
}
Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const uint16_t*>(start); it != end;) {
auto length = symbolLength16(it);
it += (4 == length) ? 2 : 1;
count += length;
}
return static_cast<Nd4jLong>(count);
}
Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
auto length = symbolLength32(it);
count += (4 == length) ? 2 : 1;;
}
return static_cast<Nd4jLong>(count*sizeof(char16_t));
}
Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end) {
Nd4jLong count = 0;
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
count += symbolLength32(it);
}
return count;
}
bool isStringValidU8(const void* start, const void* stop) {
for (auto it = static_cast<const int8_t*>(start); it != stop; it++) {
if (!isSymbolU8Valid( castToU8(*it) )) {
return false;
}
}
return true;
}
bool isStringValidU16(const void* start, const void* stop) {
for (auto it = static_cast<const uint16_t*>(start); it != stop; it++) {
if (!isSymbolValid( castToU32(*it) )) {
return false;
}
}
return true;
}
bool isStringValidU32(const void* start, const void* stop) {
for (auto it = static_cast<const uint32_t*>(start); it != stop; it++) {
if (!isSymbolValid( castToU32(*it) )) {
return false;
}
}
return true;
}
void* utf16to8Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<int8_t*>(res);
// result have to be pre-allocated
for (auto it = static_cast<const uint16_t*>(start); it != end;) {
uint32_t cp = castToU16(*it++);
if (!isLeadSurrogate(cp)) {
if (cp < 0x80) { // for one byte
*(result++) = static_cast<uint8_t>(cp);
}
else if (cp < 0x800) { // for two bytes
*(result++) = static_cast<uint8_t>((cp >> 6) | 0xc0);
*(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
}
else{ // for three bytes
*(result++) = static_cast<uint8_t>((cp >> 12) | 0xe0);
*(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
*(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
}
}
else {
if (it != end) {
uint32_t trail_surrogate = castToU16(*it++);
if (isTrailSurrogate(trail_surrogate))
cp = (cp << 10) + trail_surrogate + BYTEOFFSET;
}
// for four bytes
*(result++) = static_cast<uint8_t>((cp >> 18) | 0xf0);
*(result++) = static_cast<uint8_t>(((cp >> 12) & 0x3f) | 0x80);
*(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
*(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
}
}
return result;
}
void* utf8to16Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<uint16_t*>(res);
// result have to be pre-allocated
for (auto it = static_cast<const int8_t*>(start); it != end;) {
auto nLength = symbolLength(it);
uint32_t cp = castToU8(*it++);
if (4 != nLength) {
if (2 == nLength) {
cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f);
}
else if (3 == nLength) {
cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff);
cp += (*it++) & 0x3f;
}
*(result++) = static_cast<uint16_t>(cp);
}
else {
cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff);
cp += (castToU8(*it++) << 6) & 0xfff;
cp += (*it++) & 0x3f;
//make a surrogate pair
*(result++) = static_cast<uint16_t>((cp >> 10) + HIGHBYTEOFFSET);
*(result++) = static_cast<uint16_t>((cp & 0x3ff) + TRAILBYTEMIN);
}
}
return result;
}
void* utf32to8Ptr( const void* start, const void* end, void* result) {
auto res = static_cast<uint8_t*>(result);
// result have to be pre-allocated
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
if (*it < 0x80) // for one byte
*(res++) = static_cast<uint8_t>(*it);
else if (*it < 0x800) { // for two bytes
*(res++) = static_cast<uint8_t>((*it >> 6) | 0xc0);
*(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
}
else if (*it < 0x10000) { // for three bytes
*(res++) = static_cast<uint8_t>((*it >> 12) | 0xe0);
*(res++) = static_cast<uint8_t>(((*it >> 6) & 0x3f) | 0x80);
*(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
}
else { // for four bytes
*(res++) = static_cast<uint8_t>((*it >> 18) | 0xf0);
*(res++) = static_cast<uint8_t>(((*it >> 12) & 0x3f) | 0x80);
*(res++) = static_cast<uint8_t>(((*it >> 6) & 0x3f) | 0x80);
*(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
}
}
return result;
}
void* utf8to32Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<uint32_t*>(res);
// result have to be pre-allocated
for (auto it = static_cast<const int8_t*>(start); it != end;) {
auto nLength = symbolLength(it);
uint32_t cp = castToU8(*it++);
if (2 == nLength) {
cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f);
}
else if (3 == nLength) {
cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff);
cp += (*it++) & 0x3f;
}
else if (4 == nLength) {
cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff);
cp += (castToU8(*it++) << 6) & 0xfff;
cp += (*it++) & 0x3f;
}
(*result++) = cp;
}
return result;
}
void* utf16to32Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<uint32_t*>(res);
// result have to be pre-allocated
for (auto it = static_cast<const uint16_t*>(start); it != end; it++) {
uint32_t cpHigh = castToU32(*it);
if (!isSurrogateU16(cpHigh)) {
*result++ = cpHigh;
}
else {
it++;
uint32_t cpLow = castToU32(*it);
if (isHighSurrogate(cpHigh) && it != end && isLowSurrogate(cpLow)) {
*result++ = surrogateU32(cpHigh, cpLow);
}
}
}
return result;
}
void* utf32to16Ptr(const void* start, const void* end, void* res) {
auto result = static_cast<uint16_t*>(res);
// result have to be pre-allocate
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
uint32_t cpHigh = castToU32(*it);
// todo check do we need this as we have pre-validation, if yes find out how to check u16
if (cpHigh < 0 || cpHigh > 0x10FFFF || (cpHigh >= 0xD800 && cpHigh <= 0xDFFF)) {
// Invalid code point. Replace with sentinel, per Unicode standard:
*result++ = u'\uFFFD';
}
else if (cpHigh < 0x10000UL) { // In the BMP.
*result++ = static_cast<char16_t>(cpHigh);
}
else {
*result++ = static_cast<char16_t>(((cpHigh - 0x10000UL) / 0x400U) + 0xD800U);
*result++ = static_cast<char16_t>(((cpHigh - 0x10000UL) % 0x400U) + 0xDC00U);
}
}
return result;
}
Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize) {
return offsetUtf8StringInUtf32(input, static_cast<const int8_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize) {
return offsetUtf16StringInUtf32(input, static_cast<const uint16_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize) {
return offsetUtf8StringInUtf16(input, static_cast<const int8_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize) {
return offsetUtf16StringInUtf8(input, static_cast<const uint16_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize) {
return offsetUtf32StringInUtf8(input, static_cast<const uint32_t*>(input) + nInputSize);
}
Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize) {
return offsetUtf32StringInUtf16(input, static_cast<const uint32_t*>(input) + nInputSize);
}
bool utf8to16(const void* input, void* output, uint32_t nInputSize) {
return utf8to16Ptr(input, static_cast<const int8_t*>(input) + nInputSize, output);
}
bool utf8to32(const void* input, void* output, uint32_t nInputSize) {
return utf8to32Ptr(input, static_cast<const int8_t*>(input) + nInputSize, output);
}
bool utf16to32(const void* input, void* output, uint32_t nInputSize) {
return utf16to32Ptr(input, static_cast<const uint16_t*>(input) + nInputSize, output);
}
bool utf16to8(const void* input, void* output, uint32_t nInputSize) {
return utf16to8Ptr(input, static_cast<const uint16_t*>(input) + nInputSize, output);
}
bool utf32to16(const void* input, void* output, uint32_t nInputSize) {
return utf32to16Ptr(input, static_cast<const uint32_t*>(input) + nInputSize, output);
}
bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize) {
return utf32to8Ptr(input, static_cast<const uint32_t*>(input) + nInputSize, output);
}
}
}

View File

@ -0,0 +1,189 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//
#ifndef LIBND4J_UNICODE_H
#define LIBND4J_UNICODE_H
#include <NDArray.h>
namespace nd4j {
namespace unicode {
/**
* This method calculate u16 offset based on utf8
* @param const pointer to the utf8 string start point
* @param size of the string
* @return offset of utf16
*/
Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end);
/**
* This method calculate u8 offset based on utf16
* @param const pointer to the utf16 string start point
* @param size of the string
* @return offset of utf8
*/
Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end);
/**
* This method calculate u32 offset based on utf16
* @param const pointer to the utf16 string start point
* @param size of the string
* @return offset of utf32
*/
Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end);
/**
* This method calculate u32 offset based on utf8
* @param const pointer to the utf16 string start point
* @param size of the string
* @return offset of utf8
*/
Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end);
/*
* This function check is valid charecter in u8 string
*/
bool isStringValidU8(const void* start, const void* stop);
/*
* This function check is valid charecter in u16 string
*/
bool isStringValidU16(const void* start, const void* stop);
/*
* This function check is valid u32 charecter in string
*/
bool isStringValidU32(const void* start, const void* stop);
/**
* This method count offset for utf8 string in utf32
* @param const pointer to the utf8 string start point
* @param size of the string
* @return offset
*/
Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize);
/**
* This method count offset for utf8 string in utf32
* @param const pointer to the utf8 string start point
* @param const end pointer to the utf8 string
* @return offset
*/
Nd4jLong offsetUtf8StringInUtf32(const void* input, const void* stop);
/**
* This method count offset for utf32 based on utf16 string
* @param const pointer to the utf16 string start point
* @param size of the string
* @return offset
*/
Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize);
/**
* This method calculate offset of u16 based on utf8
* @param const pointer to the utf8 string start point
* @param size of the string
* @return offset of utf16
*/
Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize);
/**
* This method calculate offset of u8 based on utf16
* @param const pointer to the utf16 string start point
* @param size of the string
* @return offset of utf8
*/
Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize);
/**
* This method calculate offset of u32 based on utf8
* @param const pointer to the utf16 string start point
* @param size of the string
* @return offset of utf32
*/
Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize);
/**
* This method calculate offset of u32 based on utf16
* @param const pointer to the utf16 string start point
* @param size of the string
* @return offset of utf32
*/
Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize);
/**
* This method convert utf8 string to utf16 string
* @param const pointer to the utf8 string start point
* @param reference to start point to utf16
* @param size of input utf8 string
* @return status of convertion
*/
bool utf8to16(const void* input, void* output, uint32_t nInputSize);
/**
* This method convert utf8 string to utf32 string
* @param const pointer to the utf8 string start point
* @param reference to start point to utf32
* @param size of input utf8 string
* @return status of convertion
*/
bool utf8to32(const void* input, void* output, uint32_t nInputSize);
/**
* This method convert utf16 string to utf32 string
* @param const pointer to the utf16 string start point
* @param reference to start point to utf32
* @param size of input utf16 string
* @return status of convertion
*/
bool utf16to32(const void* input, void* output, uint32_t nInputSize);
/**
* This method convert utf16 string to utf8 string
* @param const pointer to the utf16 string start point
* @param reference to start point to utf8
* @param size of input utf16 string
* @return status of convertion
*/
bool utf16to8(const void* input, void* output, uint32_t nInputSize);
/**
* This method convert utf32 string to utf16 string
* @param const pointer to the utf32 string start point
* @param reference to start point to utf16
* @param size of input utf32 string
* @return status of convertion
*/
bool utf32to16(const void* input, void* output, uint32_t nInputSize);
/**
* This method convert utf32 string to utf8 string
* @param const pointer to the utf32 string start point
* @param reference to start point to utf8
* @param size of input utf32 string
* @return status of convertion
*/
bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize);
}
}
#endif //LIBND4J_UNICODE_H

View File

@ -118,7 +118,7 @@ namespace ops {
DECLARE_TYPES(Pow_bp) { DECLARE_TYPES(Pow_bp) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS }) ->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS })
->setAllowedOutputTypes({ ALL_FLOATS }); // TODO maybe wourth to add ALL_INTS ->setAllowedOutputTypes({ ALL_FLOATS });
} }
} }

View File

@ -81,7 +81,7 @@ namespace nd4j {
} }
// now once we have all strings in single vector time to fill // now once we have all strings in single vector time to fill
auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings); auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings);
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
// for CUDA mostly // for CUDA mostly

View File

@ -33,6 +33,11 @@
#include <type_boilerplate.h> #include <type_boilerplate.h>
#define LIBND4J_STRINGTYPES \
(nd4j::DataType::UTF8, std::string),\
(nd4j::DataType::UTF16, std::u16string), \
(nd4j::DataType::UTF32, std::u32string)
#define LIBND4J_TYPES \ #define LIBND4J_TYPES \
(nd4j::DataType::BFLOAT16, bfloat16),\ (nd4j::DataType::BFLOAT16, bfloat16),\
(nd4j::DataType::HALF, float16), \ (nd4j::DataType::HALF, float16), \

View File

@ -599,7 +599,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_2) {
TEST_F(BroadcastableOpsTests, broadcast_empty_3) { TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2}); NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2});
NDArray y('c', {}, {0.1}, nd4j::DataType::FLOAT32); NDArray y('c', {}, std::vector<double>{0.1}, nd4j::DataType::FLOAT32);
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});; NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
nd4j::ops::maximum op; nd4j::ops::maximum op;

View File

@ -626,7 +626,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test1) {
NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, nd4j::DataType::FLOAT32); NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, nd4j::DataType::FLOAT32);
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, nd4j::DataType::FLOAT32); NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, nd4j::DataType::FLOAT32);
NDArray expGradB('c', {iC}, {364.5}, nd4j::DataType::FLOAT32); NDArray expGradB('c', {iC}, std::vector<double>{364.5}, nd4j::DataType::FLOAT32);
input = 0.5; input = 0.5;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);

View File

@ -132,11 +132,11 @@ TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) {
NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, nd4j::DataType::BFLOAT16); NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, nd4j::DataType::BFLOAT16);
NDArray x3('c', {2,2}, {0, -1, 0, 1}, nd4j::DataType::BOOL); NDArray x3('c', {2,2}, {0, -1, 0, 1}, nd4j::DataType::BOOL);
NDArray scalar('c', {}, {0}, nd4j::DataType::INT64); NDArray scalar('c', {}, std::vector<double>{0}, nd4j::DataType::INT64);
NDArray exp1('c', {}, {3}, nd4j::DataType::INT64); NDArray exp1('c', {}, std::vector<double>{3}, nd4j::DataType::INT64);
NDArray exp2('c', {}, {2}, nd4j::DataType::INT64); NDArray exp2('c', {}, std::vector<double>{2}, nd4j::DataType::INT64);
NDArray exp3('c', {}, {1}, nd4j::DataType::INT64); NDArray exp3('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
void *dX1, *dX2, *dX3, *dZ; void *dX1, *dX2, *dX3, *dZ;
Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo; Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo;
@ -262,11 +262,11 @@ TEST_F(CudaBasicsTests1, execReduce3Scalar_1) {
NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {}, {-30.f}, nd4j::DataType::FLOAT32); NDArray exp1('c', {}, std::vector<double>{-30.f}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {}, {15.}, nd4j::DataType::DOUBLE); NDArray exp2('c', {}, std::vector<double>{15.}, nd4j::DataType::DOUBLE);
NDArray scalar1('c', {}, {100.f}, nd4j::DataType::FLOAT32); NDArray scalar1('c', {}, std::vector<double>{100.f}, nd4j::DataType::FLOAT32);
NDArray scalar2('c', {}, {100.}, nd4j::DataType::DOUBLE); NDArray scalar2('c', {}, std::vector<double>{100.}, nd4j::DataType::DOUBLE);
void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2; void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2;
Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo; Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo;
@ -363,8 +363,8 @@ TEST_F(CudaBasicsTests1, execReduce3_1) {
NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32); NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
NDArray y('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32); NDArray y('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32);
NDArray exp('c', {}, {-30.f}, nd4j::DataType::FLOAT32); NDArray exp('c', {}, std::vector<double>{-30.f}, nd4j::DataType::FLOAT32);
NDArray z('c', {}, {100.f}, nd4j::DataType::FLOAT32); NDArray z('c', {}, std::vector<double>{100.f}, nd4j::DataType::FLOAT32);
std::vector<int> dimensions = {0, 1}; std::vector<int> dimensions = {0, 1};
@ -415,8 +415,8 @@ TEST_F(CudaBasicsTests1, execReduce3_2) {
NDArray x('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
NDArray y('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE); NDArray y('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE);
NDArray exp('c', {}, {15.}, nd4j::DataType::DOUBLE); NDArray exp('c', {}, std::vector<double>{15.}, nd4j::DataType::DOUBLE);
NDArray z('c', {}, {100.}, nd4j::DataType::DOUBLE); NDArray z('c', {}, std::vector<double>{100.}, nd4j::DataType::DOUBLE);
std::vector<int> dimensions = {0, 1}; std::vector<int> dimensions = {0, 1};
@ -975,7 +975,7 @@ TEST_F(CudaBasicsTests1, execScalar_1) {
NDArray x('c', {2,3}, {0,1,2,3,4,5}, nd4j::DataType::INT64); NDArray x('c', {2,3}, {0,1,2,3,4,5}, nd4j::DataType::INT64);
NDArray exp('c',{2,3}, {0,0,1,1,2,2}, nd4j::DataType::INT64); NDArray exp('c',{2,3}, {0,0,1,1,2,2}, nd4j::DataType::INT64);
NDArray scalar('c',{}, {2.f}, nd4j::DataType::FLOAT32); NDArray scalar('c',{}, std::vector<double>{2.f}, nd4j::DataType::FLOAT32);
NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::INT64); NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::INT64);
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
@ -1010,7 +1010,7 @@ TEST_F(CudaBasicsTests1, execScalar_2) {
NDArray x('c', {2,3}, {-1,-2,-3,-4,-5,-6}, nd4j::DataType::INT64); NDArray x('c', {2,3}, {-1,-2,-3,-4,-5,-6}, nd4j::DataType::INT64);
NDArray exp('c',{2,3}, {10,10,10,10,10,10}, nd4j::DataType::FLOAT32); NDArray exp('c',{2,3}, {10,10,10,10,10,10}, nd4j::DataType::FLOAT32);
NDArray scalar('c',{}, {10.f}, nd4j::DataType::FLOAT32); NDArray scalar('c',{}, std::vector<double>{10.f}, nd4j::DataType::FLOAT32);
NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
@ -1103,7 +1103,7 @@ TEST_F(CudaBasicsTests1, execScalar_3) {
TEST_F(CudaBasicsTests1, execScalarBool_1) { TEST_F(CudaBasicsTests1, execScalarBool_1) {
NDArray x('c', {2,3}, {-1,-2,0,1,2,3}, nd4j::DataType::BFLOAT16); NDArray x('c', {2,3}, {-1,-2,0,1,2,3}, nd4j::DataType::BFLOAT16);
NDArray scalar('c',{}, {0}, nd4j::DataType::BFLOAT16); NDArray scalar('c',{}, std::vector<double>{0}, nd4j::DataType::BFLOAT16);
NDArray exp('c',{2,3}, {0,0,0,1,1,1}, nd4j::DataType::BOOL); NDArray exp('c',{2,3}, {0,0,0,1,1,1}, nd4j::DataType::BOOL);
NDArray z('c', {2,3}, {100,100,100,100,100,100,}, nd4j::DataType::BOOL); NDArray z('c', {2,3}, {100,100,100,100,100,100,}, nd4j::DataType::BOOL);
@ -2245,8 +2245,8 @@ TEST_F(CudaBasicsTests1, execReduceLong_2) {
TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) { TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) {
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32); NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32);
NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
NDArray exp('c', {}, {6.5}, nd4j::DataType::FLOAT32); NDArray exp('c', {}, std::vector<double>{6.5}, nd4j::DataType::FLOAT32);
x.permutei({2,1,0}); x.permutei({2,1,0});
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
@ -2282,8 +2282,8 @@ TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) {
TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) { TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) {
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32); NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32);
NDArray z('c', {}, {100}, nd4j::DataType::DOUBLE); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::DOUBLE);
NDArray exp('c', {}, {6.5}, nd4j::DataType::DOUBLE); NDArray exp('c', {}, std::vector<double>{6.5}, nd4j::DataType::DOUBLE);
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
cudaError_t cudaResult; cudaError_t cudaResult;
@ -2318,8 +2318,8 @@ TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) {
TEST_F(CudaBasicsTests1, execReduceSameScalar_1) { TEST_F(CudaBasicsTests1, execReduceSameScalar_1) {
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32); NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32);
NDArray z('c', {}, {100}, nd4j::DataType::INT32); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::INT32);
NDArray exp('c', {}, {156}, nd4j::DataType::INT32); NDArray exp('c', {}, std::vector<double>{156}, nd4j::DataType::INT32);
x.permutei({2,1,0}); x.permutei({2,1,0});
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
@ -2355,8 +2355,8 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_1) {
TEST_F(CudaBasicsTests1, execReduceSameScalar_2) { TEST_F(CudaBasicsTests1, execReduceSameScalar_2) {
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::DOUBLE); NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::DOUBLE);
NDArray z('c', {}, {100}, nd4j::DataType::DOUBLE); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::DOUBLE);
NDArray exp('c', {}, {156}, nd4j::DataType::DOUBLE); NDArray exp('c', {}, std::vector<double>{156}, nd4j::DataType::DOUBLE);
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
cudaError_t cudaResult; cudaError_t cudaResult;
@ -2391,8 +2391,8 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_2) {
TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) { TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) {
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::INT32); NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::INT32);
NDArray z('c', {}, {100}, nd4j::DataType::BOOL); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::BOOL);
NDArray exp('c', {}, {1}, nd4j::DataType::BOOL); NDArray exp('c', {}, std::vector<double>{1}, nd4j::DataType::BOOL);
x.permutei({2,1,0}); x.permutei({2,1,0});
x.syncShape(); x.syncShape();
@ -2429,8 +2429,8 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) {
TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) { TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) {
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::DOUBLE); NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::DOUBLE);
NDArray z('c', {}, {100}, nd4j::DataType::BOOL); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::BOOL);
NDArray exp('c', {}, {1}, nd4j::DataType::BOOL); NDArray exp('c', {}, std::vector<double>{1}, nd4j::DataType::BOOL);
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
cudaError_t cudaResult; cudaError_t cudaResult;
@ -2465,8 +2465,8 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) {
TEST_F(CudaBasicsTests1, execReduceLongScalar_1) { TEST_F(CudaBasicsTests1, execReduceLongScalar_1) {
NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::INT32); NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::INT32);
NDArray z('c', {}, {100}, nd4j::DataType::INT64); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::INT64);
NDArray exp('c', {}, {17}, nd4j::DataType::INT64); NDArray exp('c', {}, std::vector<double>{17}, nd4j::DataType::INT64);
x.permutei({2,1,0}); x.permutei({2,1,0});
x.syncShape(); x.syncShape();
@ -2503,8 +2503,8 @@ TEST_F(CudaBasicsTests1, execReduceLongScalar_1) {
TEST_F(CudaBasicsTests1, execReduceLongScalar_2) { TEST_F(CudaBasicsTests1, execReduceLongScalar_2) {
NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::DOUBLE); NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::DOUBLE);
NDArray z('c', {}, {100}, nd4j::DataType::INT64); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::INT64);
NDArray exp('c', {}, {17}, nd4j::DataType::INT64); NDArray exp('c', {}, std::vector<double>{17}, nd4j::DataType::INT64);
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
cudaError_t cudaResult; cudaError_t cudaResult;
@ -2685,8 +2685,8 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_4) {
NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::DOUBLE); NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
NDArray y('c', {2,2,3}, {10,20,30,40,50,60,70,80,90,100,110,120}, nd4j::DataType::DOUBLE); NDArray y('c', {2,2,3}, {10,20,30,40,50,60,70,80,90,100,110,120}, nd4j::DataType::DOUBLE);
NDArray exp('c', {}, {1820}, nd4j::DataType::FLOAT32); NDArray exp('c', {}, std::vector<double>{1820}, nd4j::DataType::FLOAT32);
NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
std::vector<int> dimensions = {0,1,2}; std::vector<int> dimensions = {0,1,2};
@ -2739,8 +2739,8 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_4) {
TEST_F(CudaBasicsTests1, execSummaryStats_1) { TEST_F(CudaBasicsTests1, execSummaryStats_1) {
NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64); NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64);
NDArray exp('c', {}, {3.605551}, nd4j::DataType::FLOAT32); NDArray exp('c', {}, std::vector<double>{3.605551}, nd4j::DataType::FLOAT32);
NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
cudaError_t cudaResult; cudaError_t cudaResult;
@ -2881,8 +2881,8 @@ TEST_F(CudaBasicsTests1, execSummaryStats_3) {
TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) { TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) {
NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64); NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64);
NDArray exp('c', {}, {3.605551}, nd4j::DataType::FLOAT32); NDArray exp('c', {}, std::vector<double>{3.605551}, nd4j::DataType::FLOAT32);
NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32); NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
// create cuda stream and LaunchContext // create cuda stream and LaunchContext
cudaError_t cudaResult; cudaError_t cudaResult;

View File

@ -775,7 +775,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test2
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3) { TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3) {
NDArray labels('c', {1}, {0}, nd4j::DataType::INT32); NDArray labels('c', {1}, std::vector<double>{0}, nd4j::DataType::INT32);
auto logits = NDArrayFactory::create<double>('c', {1,3}); auto logits = NDArrayFactory::create<double>('c', {1,3});
auto expected = NDArrayFactory::create<double>('c', {1}, {1.20194}); auto expected = NDArrayFactory::create<double>('c', {1}, {1.20194});
@ -2735,7 +2735,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
NDArray images ('c', {1,2,2,1}, {1,2,3,4}, nd4j::DataType::FLOAT32); NDArray images ('c', {1,2,2,1}, {1,2,3,4}, nd4j::DataType::FLOAT32);
NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32);
NDArray boxI('c', {1}, {0}, nd4j::DataType::INT64); NDArray boxI('c', {1}, std::vector<double>{0}, nd4j::DataType::INT64);
NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3}); NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
@ -2759,7 +2759,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}, nd4j::DataType::FLOAT32); NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}, nd4j::DataType::FLOAT32);
NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32); NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32);
NDArray boxI('c', {1}, {0}, nd4j::DataType::INT32); NDArray boxI('c', {1}, std::vector<double>({0.}), nd4j::DataType::INT32);
NDArray cropSize = NDArrayFactory::create<int>({3, 3}); NDArray cropSize = NDArrayFactory::create<int>({3, 3});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
@ -2933,8 +2933,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32); NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, nd4j::DataType::FLOAT32); NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, nd4j::DataType::FLOAT32);
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32); NDArray min('c', {}, std::vector<double>{-63.65f}, nd4j::DataType::FLOAT32);
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32); NDArray max('c', {}, std::vector<double>{0.1f}, nd4j::DataType::FLOAT32);
nd4j::ops::fake_quant_with_min_max_vars op; nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.evaluate({&x, &min, &max}, {}, {}); auto results = op.evaluate({&x, &min, &max}, {}, {});

View File

@ -121,7 +121,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692, NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
NDArray dLdwExp('c', {}, {-227.77286}); NDArray dLdwExp('c', {}, std::vector<double>{-227.77286});
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
@ -246,7 +246,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {}, {0.}); NDArray dLdwExp('c', {}, std::vector<double>{0.});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -350,7 +350,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {1,1}, {-9.49054}); NDArray dLdwExp('c', {1,1}, std::vector<double>{-9.49054});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -1611,7 +1611,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
NDArray dLdwExp('c', {}, {4515.84}); NDArray dLdwExp('c', {}, std::vector<double>{4515.84});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -1730,7 +1730,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {}, {0.}); NDArray dLdwExp('c', {}, std::vector<double>{0.});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -1830,7 +1830,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {1,1}, {188.16}); NDArray dLdwExp('c', {1,1}, std::vector<double>{188.16});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -2056,7 +2056,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
NDArray dLdwExp('c', {}, {288.}); NDArray dLdwExp('c', {}, std::vector<double>{288.});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -2175,7 +2175,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {}, {0.}); NDArray dLdwExp('c', {}, std::vector<double>{0.});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -2275,7 +2275,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {1,1}, {12.}); NDArray dLdwExp('c', {1,1}, std::vector<double>{12.});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -2541,7 +2541,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
NDArray dLdwExp('c', {}, {-91.52109}); NDArray dLdwExp('c', {}, std::vector<double>{-91.52109});
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
@ -2664,7 +2664,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {}, {0.}); NDArray dLdwExp('c', {}, std::vector<double>{0.});
logits.linspace(-0.08, 0.04); logits.linspace(-0.08, 0.04);
labels.linspace(1); labels.linspace(1);
@ -2766,7 +2766,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) {
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {1,1}, {-3.81338}); NDArray dLdwExp('c', {1,1}, std::vector<double>{-3.81338});
logits.linspace(-0.08, 0.04); logits.linspace(-0.08, 0.04);
labels.linspace(1); labels.linspace(1);
@ -2992,7 +2992,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) {
NDArray weights('c', {1}, nd4j::DataType::DOUBLE); NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
NDArray dLdwExp('c', {1}, {1.38629}); NDArray dLdwExp('c', {1}, std::vector<double>{1.38629});
logits = 2.; logits = 2.;
weights.assign(0.5); weights.assign(0.5);
@ -3020,10 +3020,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32); NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE); NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE); NDArray weights('c', {}, std::vector<double>{0}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
NDArray dLdwExp('c', {}, {1.38629}); NDArray dLdwExp('c', {}, std::vector<double>{1.38629});
logits = 2.; logits = 2.;
weights.assign(0.5); weights.assign(0.5);
@ -3051,10 +3051,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32); NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE); NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE); NDArray weights('c', {}, std::vector<double>{0}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519}); NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
NDArray dLdwExp('c', {}, {0.}); NDArray dLdwExp('c', {}, std::vector<double>{0.});
logits.linspace(-0.08, 0.04); logits.linspace(-0.08, 0.04);
weights = 0.5; weights = 0.5;
@ -3085,7 +3085,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) {
NDArray weights('c', {1}, nd4j::DataType::DOUBLE); NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326}); NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326});
NDArray dLdwExp('c', {1}, {1.36729}); NDArray dLdwExp('c', {1}, std::vector<double>{1.36729});
logits.linspace(-0.08, 0.04); logits.linspace(-0.08, 0.04);
weights = 0.5; weights = 0.5;
@ -3321,7 +3321,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) {
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) { TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
NDArray labels('c', {2,1}, {1,0}); NDArray labels('c', {2,1}, std::vector<double>{1,0});
NDArray logits('c', {2,1}, {-0.04, 0.04}); NDArray logits('c', {2,1}, {-0.04, 0.04});
NDArray dLdpExp('c', {2,1}, {-0.51999, 0.51999}); NDArray dLdpExp('c', {2,1}, {-0.51999, 0.51999});
@ -3343,10 +3343,10 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) { TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) {
NDArray labels('c', {1,2}, {1,1}); NDArray labels('c', {1,2}, {1,1.});
NDArray logits('c', {1,2}, {-0.04, 0.04}); NDArray logits('c', {1,2}, {-0.04, 0.04});
NDArray dLdpExp('c', {1,2}, {0, 0}); NDArray dLdpExp('c', {1,2}, {0, 0.});
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
@ -3387,10 +3387,10 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) {
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) { TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) {
NDArray labels('c', {1}, {1}); NDArray labels('c', {1}, std::vector<double>{1});
NDArray logits('c', {1}, {0.04}); NDArray logits('c', {1}, std::vector<double>{0.04});
NDArray dLdpExp('c', {1}, {0}); NDArray dLdpExp('c', {1}, std::vector<double>{0});
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op; nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
@ -3483,7 +3483,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) { TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
NDArray labels('c', {}, {1}, nd4j::DataType::INT64); NDArray labels('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
NDArray logits('c', {2}, {-0.2, 0.3}); NDArray logits('c', {2}, {-0.2, 0.3});
NDArray dLdpExp('c', {2}, {0.37754, -0.37754}); NDArray dLdpExp('c', {2}, {0.37754, -0.37754});
@ -3529,7 +3529,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) {
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) { TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) {
NDArray labels('c', {1,1}, {0}, nd4j::DataType::INT64); NDArray labels('c', {1,1}, std::vector<double>({0}), nd4j::DataType::INT64);
NDArray logits('c', {1,1,2}, {-0.3,0.2}); NDArray logits('c', {1,1,2}, {-0.3,0.2});
NDArray dLdpExp('c', {1,1,2}, {-0.62246, 0.62246}); NDArray dLdpExp('c', {1,1,2}, {-0.62246, 0.62246});

View File

@ -127,7 +127,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) {
NDArray weights('c', {1}, nd4j::DataType::DOUBLE); NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7}); NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7});
NDArray dLdwExp('c', {1}, {1.3}); NDArray dLdwExp('c', {1}, std::vector<double>{1.3});
NDArray dLdlExp('c', {4}, {0.2, 0.1, -0. , -0.1}); NDArray dLdlExp('c', {4}, {0.2, 0.1, -0. , -0.1});
predictions.linspace(-0.4, 0.2); predictions.linspace(-0.4, 0.2);
@ -158,10 +158,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) {
NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4}); NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4});
NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {}, {0.}, nd4j::DataType::DOUBLE); NDArray weights('c', {}, std::vector<double>{0.}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7}); NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7});
NDArray dLdwExp('c', {}, {1.3}); NDArray dLdwExp('c', {}, std::vector<double>{1.3});
NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1}); NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1});
predictions.linspace(-0.4, 0.2); predictions.linspace(-0.4, 0.2);
@ -196,7 +196,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) {
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE); NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.1, -0.3, -2. , 1.4}); NDArray dLdpExp('c', {4}, {0.1, -0.3, -2. , 1.4});
NDArray dLdwExp('c', {1,1}, {0.}); NDArray dLdwExp('c', {1,1}, std::vector<double>{0.});
NDArray dLdlExp('c', {4}, {0.4, 0.2, -0. , -0.2}); NDArray dLdlExp('c', {4}, {0.4, 0.2, -0. , -0.2});
predictions.linspace(-0.4, 0.2); predictions.linspace(-0.4, 0.2);
@ -369,10 +369,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) {
TEST_F(DeclarableOpsTests12, hinge_loss_14) { TEST_F(DeclarableOpsTests12, hinge_loss_14) {
NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE); NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {}, {1.}); NDArray weights('c', {}, std::vector<double>{1.});
NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0}); NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0});
NDArray output('c', {}, {0.}, nd4j::DataType::DOUBLE); NDArray output('c', {}, std::vector<double>{0.}, nd4j::DataType::DOUBLE);
logits.linspace(1.); logits.linspace(1.);
weights.assign(1.); weights.assign(1.);
@ -576,7 +576,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) {
TEST_F(DeclarableOpsTests12, reverse_test15) { TEST_F(DeclarableOpsTests12, reverse_test15) {
NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE); NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE);
NDArray axis('c', {}, {0}, nd4j::DataType::INT32); NDArray axis('c', {}, std::vector<double>{0}, nd4j::DataType::INT32);
NDArray z('c', {5}, nd4j::DataType::DOUBLE); NDArray z('c', {5}, nd4j::DataType::DOUBLE);
NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE); NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE);
@ -711,7 +711,7 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, tensormmul_6) { TEST_F(DeclarableOpsTests12, tensormmul_6) {
NDArray x('c', {1}, {2}, nd4j::DataType::FLOAT32); NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32); NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32); NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
@ -1140,9 +1140,9 @@ TEST_F(DeclarableOpsTests12, lrn_bp_9) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, lrn_bp_10) { TEST_F(DeclarableOpsTests12, lrn_bp_10) {
NDArray input('c', {1,1,1,1}, {1}); NDArray input('c', {1,1,1,1}, std::vector<double>{1});
NDArray gradO('c', {1,1,1,1}, {1}); NDArray gradO('c', {1,1,1,1}, std::vector<double>{1});
NDArray exp('c', {1,1,1,1}, {0.19245008}); NDArray exp('c', {1,1,1,1}, std::vector<double>{0.19245008});
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
@ -1193,8 +1193,8 @@ TEST_F(DeclarableOpsTests12, lrn_2) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, lrn_3) { TEST_F(DeclarableOpsTests12, lrn_3) {
NDArray input('c', {1,1,1,1}, {1.}); NDArray input('c', {1,1,1,1}, std::vector<double>{1.});
NDArray exp('c', {1,1,1,1}, {0.69006556}); NDArray exp('c', {1,1,1,1}, std::vector<double>{0.69006556});
nd4j::ops::lrn op; nd4j::ops::lrn op;
@ -1208,8 +1208,8 @@ TEST_F(DeclarableOpsTests12, lrn_3) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, lrn_4) { TEST_F(DeclarableOpsTests12, lrn_4) {
NDArray input('c', {1,1,1,1}, {1.}); NDArray input('c', {1,1,1,1}, std::vector<double>{1.});
NDArray exp('c', {1,1,1,1}, {0.69006556}); NDArray exp('c', {1,1,1,1}, std::vector<double>{0.69006556});
nd4j::ops::lrn op; nd4j::ops::lrn op;
@ -1239,10 +1239,10 @@ TEST_F(DeclarableOpsTests12, lrn_5) {
TEST_F(DeclarableOpsTests12, inTopK_1) { TEST_F(DeclarableOpsTests12, inTopK_1) {
NDArray x('c', {4, 5}, {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0}); NDArray x('c', {4, 5}, {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0});
NDArray y('c', {4}, {0, 0, 0, 0}, nd4j::DataType::INT64); NDArray y('c', {4}, {0., 0, 0, 0}, nd4j::DataType::INT64);
NDArray z('c', {4}, {1, 1, 1, 1}, nd4j::DataType::BOOL); NDArray z('c', {4}, {1., 1, 1, 1}, nd4j::DataType::BOOL);
NDArray expV('c', {4}, {1, 0, 0, 0}, nd4j::DataType::BOOL); NDArray expV('c', {4}, {1., 0, 0, 0}, nd4j::DataType::BOOL);
nd4j::ops::in_top_k op; nd4j::ops::in_top_k op;
Nd4jStatus status = op.execute({&x, &y, }, {&z}, {}, {2}, {}); Nd4jStatus status = op.execute({&x, &y, }, {&z}, {}, {2}, {});

View File

@ -809,7 +809,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_1) {
NDArray x('c', {1, 2, 2, 2, 3}, nd4j::DataType::FLOAT32); NDArray x('c', {1, 2, 2, 2, 3}, nd4j::DataType::FLOAT32);
NDArray blockShape('c', {3}, {2, 2, 2} , nd4j::DataType::INT32); // three spatial dimensions NDArray blockShape('c', {3}, {2, 2, 2} , nd4j::DataType::INT32); // three spatial dimensions
NDArray paddings('c', {3, 2}, {0, 0, 0, 0, 0, 0} , nd4j::DataType::INT32); NDArray paddings('c', {3, 2}, std::vector<double>{0, 0, 0, 0, 0, 0} , nd4j::DataType::INT32);
NDArray exp('c', {8, 1, 1, 1, 3}, nd4j::DataType::FLOAT32); NDArray exp('c', {8, 1, 1, 1, 3}, nd4j::DataType::FLOAT32);
@ -892,8 +892,8 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_1) {
NDArray x('c', {8, 1, 1, 1, 3}, nd4j::DataType::FLOAT32); NDArray x('c', {8, 1, 1, 1, 3}, nd4j::DataType::FLOAT32);
NDArray blockShape('c', {3}, {2, 2, 2} , nd4j::DataType::INT32); // three spatial dimensions NDArray blockShape('c', {3}, {2., 2, 2} , nd4j::DataType::INT32); // three spatial dimensions
NDArray crop('c', {3, 2}, {0, 0, 0, 0, 0, 0} , nd4j::DataType::INT32); NDArray crop('c', {3, 2}, {0., 0, 0, 0, 0, 0} , nd4j::DataType::INT32);
NDArray exp('c', {1, 2, 2, 2, 3}, nd4j::DataType::FLOAT32); NDArray exp('c', {1, 2, 2, 2, 3}, nd4j::DataType::FLOAT32);
@ -990,7 +990,7 @@ TEST_F(DeclarableOpsTests13, mergemax_1) {
TEST_F(DeclarableOpsTests13, mergemax_2) { TEST_F(DeclarableOpsTests13, mergemax_2) {
NDArray x1('c', {1, 3}, {0., 1, 2}, nd4j::DataType::FLOAT32); NDArray x1('c', {1, 3}, {0., 1, 2}, nd4j::DataType::FLOAT32);
NDArray x2('c', {1, 1}, {1.}, nd4j::DataType::FLOAT32); NDArray x2('c', {1, 1}, std::vector<double>{1.}, nd4j::DataType::FLOAT32);
NDArray out('c', {1, 3}, {-1., -1, -1}, nd4j::DataType::FLOAT32); NDArray out('c', {1, 3}, {-1., -1, -1}, nd4j::DataType::FLOAT32);
nd4j::ops::mergemax op; nd4j::ops::mergemax op;
@ -2143,10 +2143,10 @@ TEST_F(DeclarableOpsTests13, batchnorm_test7) {
NDArray input2('c', {3,15,15,3}, nd4j::DataType::FLOAT32); NDArray input2('c', {3,15,15,3}, nd4j::DataType::FLOAT32);
input2.permutei({0,3,1,2}); input2.permutei({0,3,1,2});
NDArray mean ('c', {3}, {0, 0, 0}, nd4j::DataType::FLOAT32); NDArray mean ('c', {3}, {0., 0, 0}, nd4j::DataType::FLOAT32);
NDArray variance('c', {3}, {1, 1, 1}, nd4j::DataType::FLOAT32); NDArray variance('c', {3}, {1., 1, 1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {3}, {1, 1, 1}, nd4j::DataType::FLOAT32); NDArray gamma ('c', {3}, {1., 1, 1}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {3}, {0, 0, 0}, nd4j::DataType::FLOAT32); NDArray beta ('c', {3}, {0., 0, 0}, nd4j::DataType::FLOAT32);
NDArray out1('c', {3,3,15,15}, nd4j::DataType::FLOAT32); NDArray out1('c', {3,3,15,15}, nd4j::DataType::FLOAT32);
NDArray out2('c', {3,3,15,15}, nd4j::DataType::FLOAT32); NDArray out2('c', {3,3,15,15}, nd4j::DataType::FLOAT32);

View File

@ -858,7 +858,7 @@ TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) {
TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) { TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) {
// rank 1 // rank 1
NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32); NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32);
NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32); NDArray expected('c', { 1 }, std::vector<double>{ 55 }, nd4j::DataType::INT32);
nd4j::ops::rgb_to_grs op; nd4j::ops::rgb_to_grs op;
auto result = op.evaluate({&rgbs}, {}, {}); auto result = op.evaluate({&rgbs}, {}, {});
auto output = result->at(0); auto output = result->at(0);
@ -1395,7 +1395,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test6) {
y.assign(4.0); y.assign(4.0);
dLdzC.linspace(0.1, 0.1); dLdzC.linspace(0.1, 0.1);
NDArray dLdxExpXC('c', { 1 }, { 115.2 }, nd4j::DataType::FLOAT32); NDArray dLdxExpXC('c', { 1 }, std::vector<double>{ 115.2 }, nd4j::DataType::FLOAT32);
NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32); NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32);
nd4j::ops::Pow_bp op; nd4j::ops::Pow_bp op;

View File

@ -55,11 +55,11 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
} }
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) { TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
auto values = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); auto values = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"});
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3}); auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2}); auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
auto def = NDArrayFactory::string("d"); auto def = NDArrayFactory::string("d");
auto exp = NDArrayFactory::string('c', {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"}); auto exp = NDArrayFactory::string( {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"});
nd4j::ops::compat_sparse_to_dense op; nd4j::ops::compat_sparse_to_dense op;
@ -70,11 +70,11 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
} }
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) { TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"}); auto x = NDArrayFactory::string( {2}, {"first string", "second"});
auto delimiter = NDArrayFactory::string(" "); auto delimiter = NDArrayFactory::string(" ");
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0}); auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"}); auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"});
nd4j::ops::compat_string_split op; nd4j::ops::compat_string_split op;
auto result = op.evaluate({&x, &delimiter}); auto result = op.evaluate({&x, &delimiter});

View File

@ -79,7 +79,7 @@ TEST_F(DeclarableOpsTests2, gather_2) {
TEST_F(DeclarableOpsTests2, gather_3) { TEST_F(DeclarableOpsTests2, gather_3) {
NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
NDArray indices ('c', {1,1}, {2}, nd4j::DataType::INT32); NDArray indices ('c', {1,1}, std::vector<double>{2}, nd4j::DataType::INT32);
NDArray expected('c', {2,1,1,4}, {9,10,11,12,21,22,23,24}); NDArray expected('c', {2,1,1,4}, {9,10,11,12,21,22,23,24});
nd4j::ops::gather op; nd4j::ops::gather op;
@ -186,7 +186,7 @@ TEST_F(DeclarableOpsTests2, gather_7) {
TEST_F(DeclarableOpsTests2, gather_8) { TEST_F(DeclarableOpsTests2, gather_8) {
NDArray input('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, nd4j::DataType::FLOAT32); NDArray input('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, nd4j::DataType::FLOAT32);
NDArray indices('c', {1}, {2}, nd4j::DataType::INT32); NDArray indices('c', {1}, std::vector<double>{2}, nd4j::DataType::INT32);
NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, nd4j::DataType::FLOAT32); NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, nd4j::DataType::FLOAT32);
nd4j::ops::gather op; nd4j::ops::gather op;
@ -206,7 +206,7 @@ TEST_F(DeclarableOpsTests2, gather_8) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, gather_9) { TEST_F(DeclarableOpsTests2, gather_9) {
NDArray x('c', {2, 4, 3, 2}, nd4j::DataType::FLOAT32); NDArray x('c', {2, 4, 3, 2}, nd4j::DataType::FLOAT32);
NDArray indices('c', {2}, {1, 0}, nd4j::DataType::INT32); NDArray indices('c', {2}, std::vector<double>{1, 0}, nd4j::DataType::INT32);
nd4j::ops::gather op; nd4j::ops::gather op;
auto result = op.evaluate({&x, &indices}, {}, {-2}); auto result = op.evaluate({&x, &indices}, {}, {-2});
@ -238,7 +238,7 @@ TEST_F(DeclarableOpsTests2, gather_10) {
TEST_F(DeclarableOpsTests2, gather_11) { TEST_F(DeclarableOpsTests2, gather_11) {
NDArray x('c', {2, 2}, {1, 2, 3, 4}); NDArray x('c', {2, 2}, {1, 2, 3, 4});
NDArray indices('c', {2}, {1, 0}, nd4j::DataType::INT64); NDArray indices('c', {2}, std::vector<double>{1, 0}, nd4j::DataType::INT64);
NDArray e('c', {2, 2}, {3, 4, 1, 2}); NDArray e('c', {2, 2}, {3, 4, 1, 2});
nd4j::ops::gather op; nd4j::ops::gather op;

View File

@ -243,7 +243,7 @@ TEST_F(DeclarableOpsTests5, Test_SetSeed_1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, scatterMul_test1) { TEST_F(DeclarableOpsTests5, scatterMul_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); NDArray idc('c', {1}, std::vector<double>({0LL}), nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f}); auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10.f, 2.f, 3.f, 4.f});
@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests5, scatterMul_test1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, scatterDiv_test1) { TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); NDArray idc('c', {1}, std::vector<double>({0LL}), nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f}); auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f});
@ -279,7 +279,7 @@ TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, scatterSub_test1) { TEST_F(DeclarableOpsTests5, scatterSub_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); NDArray idc('c', {1}, std::vector<double>({0LL}), nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f}); auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f});

View File

@ -1411,7 +1411,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) { TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0}); auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0});
NDArray exp('c', {1}, {-54.0}); NDArray exp('c', {1}, std::vector<double>{-54.0});
nd4j::ops::matrix_determinant op; nd4j::ops::matrix_determinant op;
auto result = op.evaluate({&x}, {}, {}); auto result = op.evaluate({&x}, {}, {});
@ -1453,7 +1453,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) { TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
auto x = NDArrayFactory::create<double>('c', {1, 4, 4}); auto x = NDArrayFactory::create<double>('c', {1, 4, 4});
NDArray exp('c', {1}, {-16.0}); NDArray exp('c', {1}, std::vector<double>{-16.0});
x.linspace(1); x.linspace(1);
x.p(5, 4.0); x.p(5, 4.0);
x.p(12, 12.0); x.p(12, 12.0);

View File

@ -83,7 +83,7 @@ TEST_F(FlatUtilsTests, flat_bool_serde_1) {
} }
TEST_F(FlatUtilsTests, flat_string_serde_1) { TEST_F(FlatUtilsTests, flat_string_serde_1) {
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"});
flatbuffers::FlatBufferBuilder builder(1024); flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array); auto flatArray = FlatUtils::toFlatArray(builder, array);

View File

@ -1277,14 +1277,14 @@ TEST_F(JavaInteropTests, test_size_dtype_1) {
} }
TEST_F(JavaInteropTests, test_expandable_array_op_1) { TEST_F(JavaInteropTests, test_expandable_array_op_1) {
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"}); auto x = NDArrayFactory::string( {2}, {"first string", "second"});
auto d = NDArrayFactory::string(" "); auto d = NDArrayFactory::string(" ", nd4j::DataType::UTF8);
auto z0 = NDArrayFactory::create<Nd4jLong>('c', {6}); auto z0 = NDArrayFactory::create<Nd4jLong>('c', {6});
auto z1 = NDArrayFactory::string('c', {3}, {"", "", ""}); auto z1 = NDArrayFactory::string( {3}, {"", "", ""});
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0}); auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"}); auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"});
InteropDataBuffer iz0(z0.dataBuffer()); InteropDataBuffer iz0(z0.dataBuffer());
InteropDataBuffer iz1(z1.dataBuffer()); InteropDataBuffer iz1(z1.dataBuffer());

View File

@ -204,7 +204,7 @@ TEST_F(MultiDataTypeTests, ndarray_repeat_test1) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_bufferAsT_test1) { TEST_F(MultiDataTypeTests, ndarray_bufferAsT_test1) {
NDArray x('f', {2}, {1.5, 3.5}, nd4j::DataType::FLOAT32); NDArray x('f', {2}, {1.5, 3.5}, nd4j::DataType::FLOAT32);
NDArray y('c', {}, {1.5}, nd4j::DataType::FLOAT32); NDArray y('c', {}, std::vector<double>{1.5}, nd4j::DataType::FLOAT32);
const int* buffX = x.bufferAsT<int>(); const int* buffX = x.bufferAsT<int>();
const int* buffY = y.bufferAsT<int>(); const int* buffY = y.bufferAsT<int>();
@ -217,8 +217,8 @@ TEST_F(MultiDataTypeTests, ndarray_assign_test1) {
NDArray x('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::UINT8); NDArray x('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::UINT8);
NDArray exp('c', {2,2}, {10, 10, 20, 20}, nd4j::DataType::UINT8); NDArray exp('c', {2,2}, {10, 10, 20, 20}, nd4j::DataType::UINT8);
NDArray scalar1('c', {}, {10.5}, nd4j::DataType::FLOAT32); NDArray scalar1('c', {}, std::vector<double>{10.5}, nd4j::DataType::FLOAT32);
NDArray scalar2('c', {}, {20.8}, nd4j::DataType::DOUBLE); NDArray scalar2('c', {}, std::vector<double>{20.8}, nd4j::DataType::DOUBLE);
x(0,{0}).assign(scalar1); x(0,{0}).assign(scalar1);
x(1,{0}).assign(scalar2); x(1,{0}).assign(scalar2);
@ -233,9 +233,9 @@ TEST_F(MultiDataTypeTests, ndarray_assign_test1) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) {
NDArray x('f', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::HALF); NDArray x('f', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::HALF);
NDArray exp1('c', {}, {3}, nd4j::DataType::INT64); NDArray exp1('c', {}, std::vector<double>{3}, nd4j::DataType::INT64);
NDArray exp2('c', {1,1}, {1}, nd4j::DataType::INT64); NDArray exp2('c', {1,1}, std::vector<double>{1}, nd4j::DataType::INT64);
NDArray exp3('c', {2}, {1,2}, nd4j::DataType::INT64); NDArray exp3('c', {2}, std::vector<double>{1,2}, nd4j::DataType::INT64);
auto scalar1 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {}/*whole range*/); auto scalar1 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {}/*whole range*/);
ASSERT_EQ(scalar1, exp1); ASSERT_EQ(scalar1, exp1);
@ -250,7 +250,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) {
NDArray x('c', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT32); NDArray x('c', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT32);
NDArray exp1('c', {}, {1.5}, nd4j::DataType::FLOAT32); NDArray exp1('c', {}, std::vector<double>{1.5}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {2}, {0.5,2.5}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2}, {0.5,2.5}, nd4j::DataType::FLOAT32);
auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Mean, {}/*whole range*/); auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Mean, {}/*whole range*/);
@ -265,7 +265,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) {
NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::HALF); NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::HALF);
NDArray exp1('c', {}, {8.}, nd4j::DataType::HALF); NDArray exp1('c', {}, std::vector<double>{8.}, nd4j::DataType::HALF);
NDArray exp2('c', {2}, {2.,6.}, nd4j::DataType::HALF); NDArray exp2('c', {2}, {2.,6.}, nd4j::DataType::HALF);
auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Sum, {}/*whole range*/); auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Sum, {}/*whole range*/);
@ -278,8 +278,8 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) {
NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, nd4j::DataType::HALF); NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, nd4j::DataType::HALF);
NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL); NDArray exp1('c', {}, std::vector<double>{1}, nd4j::DataType::BOOL);
NDArray exp2('c', {2}, {1,0}, nd4j::DataType::BOOL); NDArray exp2('c', {2}, std::vector<double>{1, 0}, nd4j::DataType::BOOL);
auto scalar1 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {}/*whole range*/); auto scalar1 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {}/*whole range*/);
ASSERT_EQ(scalar1, exp1); ASSERT_EQ(scalar1, exp1);
@ -291,8 +291,8 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_varianceNumber_test1) { TEST_F(MultiDataTypeTests, ndarray_varianceNumber_test1) {
NDArray x('f', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT64); NDArray x('f', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT64);
NDArray exp1('c', {}, {1.666666667}, nd4j::DataType::FLOAT32); NDArray exp1('c', {}, std::vector<double>{1.666666667}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {}, {1.118033989}, nd4j::DataType::FLOAT32); NDArray exp2('c', {}, std::vector<double>{1.118033989}, nd4j::DataType::FLOAT32);
auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance); auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance);
ASSERT_EQ(scalar1, exp1); ASSERT_EQ(scalar1, exp1);
@ -475,8 +475,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test1) {
if (!Environment::getInstance()->isExperimentalBuild()) if (!Environment::getInstance()->isExperimentalBuild())
return; return;
NDArray scalar1('c', {0}, {4}, nd4j::DataType::INT32); NDArray scalar1('c', {0}, std::vector<double>{4}, nd4j::DataType::INT32);
NDArray scalar2('c', {0}, {1.5}, nd4j::DataType::HALF); NDArray scalar2('c', {0}, std::vector<double>{1.5}, nd4j::DataType::HALF);
NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32); NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32);
NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64); NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64);
@ -485,8 +485,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test1) {
NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF); NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF);
NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32); NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {0}, {5}, nd4j::DataType::INT32); NDArray exp1('c', {0}, std::vector<double>{5}, nd4j::DataType::INT32);
NDArray exp2('c', {0}, {6.5}, nd4j::DataType::HALF); NDArray exp2('c', {0}, std::vector<double>{6.5}, nd4j::DataType::HALF);
NDArray exp3('c', {3,2}, {11, 22, 33, 44, 55, 66}, nd4j::DataType::INT64); NDArray exp3('c', {3,2}, {11, 22, 33, 44, 55, 66}, nd4j::DataType::INT64);
NDArray exp4('c', {2,3}, {12.5, 24.5, 36.5, 48.5, 60.5, 72.5}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2,3}, {12.5, 24.5, 36.5, 48.5, 60.5, 72.5}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2,2}, {0.4, 1.5, 2.4, 3.5}, nd4j::DataType::HALF); NDArray exp5('c', {2,2}, {0.4, 1.5, 2.4, 3.5}, nd4j::DataType::HALF);
@ -553,8 +553,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test1) {
if (!Environment::getInstance()->isExperimentalBuild()) if (!Environment::getInstance()->isExperimentalBuild())
return; return;
NDArray scalar1('c', {0}, {4}, nd4j::DataType::INT32); NDArray scalar1('c', {0}, std::vector<double>{4}, nd4j::DataType::INT32);
NDArray scalar2('c', {0}, {1.5}, nd4j::DataType::HALF); NDArray scalar2('c', {0}, std::vector<double>{1.5}, nd4j::DataType::HALF);
NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32); NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32);
NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64); NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64);
@ -563,8 +563,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test1) {
NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF); NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF);
NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32); NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {0}, {2}, nd4j::DataType::INT32); NDArray exp1('c', {0}, std::vector<double>{2}, nd4j::DataType::INT32);
NDArray exp2('c', {0}, {-0.5}, nd4j::DataType::HALF); NDArray exp2('c', {0}, std::vector<double>{-0.5}, nd4j::DataType::HALF);
NDArray exp3('c', {3,2}, {8, 17, 26, 35, 44, 53}, nd4j::DataType::INT64); NDArray exp3('c', {3,2}, {8, 17, 26, 35, 44, 53}, nd4j::DataType::INT64);
NDArray exp4('c', {2,3}, {-6.5, -14.5, -22.5, -30.5, -38.5, -46.5}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2,3}, {-6.5, -14.5, -22.5, -30.5, -38.5, -46.5}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2,2}, {0.4, -0.5, -1.6, -2.5}, nd4j::DataType::HALF); NDArray exp5('c', {2,2}, {0.4, -0.5, -1.6, -2.5}, nd4j::DataType::HALF);
@ -631,8 +631,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test1) {
if (!Environment::getInstance()->isExperimentalBuild()) if (!Environment::getInstance()->isExperimentalBuild())
return; return;
NDArray scalar1('c', {0}, {3}, nd4j::DataType::INT32); NDArray scalar1('c', {0}, std::vector<double>{3}, nd4j::DataType::INT32);
NDArray scalar2('c', {0}, {2.5}, nd4j::DataType::HALF); NDArray scalar2('c', {0}, std::vector<double>{2.5}, nd4j::DataType::HALF);
NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32); NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32);
NDArray x2('c', {3,2}, {1, 2, 3, 4, 5, 6}, nd4j::DataType::INT64); NDArray x2('c', {3,2}, {1, 2, 3, 4, 5, 6}, nd4j::DataType::INT64);
@ -641,8 +641,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test1) {
NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF); NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF);
NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32); NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {0}, {7}, nd4j::DataType::INT32); NDArray exp1('c', {0}, std::vector<double>{7}, nd4j::DataType::INT32);
NDArray exp2('c', {0}, {17.5}, nd4j::DataType::HALF); NDArray exp2('c', {0}, std::vector<double>{17.5}, nd4j::DataType::HALF);
NDArray exp3('c', {3,2}, {1, 5, 10, 18, 27, 39}, nd4j::DataType::INT64); NDArray exp3('c', {3,2}, {1, 5, 10, 18, 27, 39}, nd4j::DataType::INT64);
NDArray exp4('c', {2,3}, {1.5, 12.5, 35, 81, 148.5, 253.5}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2,3}, {1.5, 12.5, 35, 81, 148.5, 253.5}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2,2}, {0., 0.5, 0.8, 1.5}, nd4j::DataType::HALF); NDArray exp5('c', {2,2}, {0., 0.5, 0.8, 1.5}, nd4j::DataType::HALF);
@ -709,8 +709,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test1) {
if (!Environment::getInstance()->isExperimentalBuild()) if (!Environment::getInstance()->isExperimentalBuild())
return; return;
NDArray scalar1('c', {0}, {3}, nd4j::DataType::INT32); NDArray scalar1('c', {0}, std::vector<double>{3}, nd4j::DataType::INT32);
NDArray scalar2('c', {0}, {2.5}, nd4j::DataType::HALF); NDArray scalar2('c', {0}, std::vector<double>{2.5}, nd4j::DataType::HALF);
NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32); NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32);
NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64); NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64);
@ -719,8 +719,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test1) {
NDArray x5('c', {2,2}, {1, 2, 3, 4}, nd4j::DataType::HALF); NDArray x5('c', {2,2}, {1, 2, 3, 4}, nd4j::DataType::HALF);
NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32); NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {0}, {1}, nd4j::DataType::INT32); NDArray exp1('c', {0}, std::vector<double>{1}, nd4j::DataType::INT32);
NDArray exp2('c', {0}, {2.5}, nd4j::DataType::HALF); NDArray exp2('c', {0}, std::vector<double>{2.5}, nd4j::DataType::HALF);
NDArray exp3('c', {3,2}, {6, 8, 8, 8, 9, 9}, nd4j::DataType::INT64); NDArray exp3('c', {3,2}, {6, 8, 8, 8, 9, 9}, nd4j::DataType::INT64);
NDArray exp4('c', {2,3}, {0.25, 0.3125, 0.4375, 0.5625, 0.611111111, 0.722222222}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2,3}, {0.25, 0.3125, 0.4375, 0.5625, 0.611111111, 0.722222222}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2,2}, {0.4, 0.25, 0.1333333, 0.125}, nd4j::DataType::HALF); NDArray exp5('c', {2,2}, {0.4, 0.25, 0.1333333, 0.125}, nd4j::DataType::HALF);
@ -792,10 +792,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberFloat_test1) {
NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE); NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE);
NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL);
NDArray exp1('c', {0}, {1.5}, nd4j::DataType::FLOAT32); NDArray exp1('c', {0}, std::vector<double>{1.5}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {0}, {2}, nd4j::DataType::HALF); NDArray exp2('c', {0}, std::vector<double>{2}, nd4j::DataType::HALF);
NDArray exp3('c', {0}, {2}, nd4j::DataType::DOUBLE); NDArray exp3('c', {0}, std::vector<double>{2}, nd4j::DataType::DOUBLE);
NDArray exp4('c', {0}, {0.25},nd4j::DataType::FLOAT32); NDArray exp4('c', {0}, std::vector<double>{0.25},nd4j::DataType::FLOAT32);
NDArray scalar = x1.reduceNumber(reduce::Mean); NDArray scalar = x1.reduceNumber(reduce::Mean);
@ -829,10 +829,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberSame_test1) {
NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE); NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE);
NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL);
NDArray exp1('c', {0}, {6}, nd4j::DataType::INT64); NDArray exp1('c', {0}, std::vector<double>{6}, nd4j::DataType::INT64);
NDArray exp2('c', {0}, {8}, nd4j::DataType::HALF); NDArray exp2('c', {0}, std::vector<double>{8}, nd4j::DataType::HALF);
NDArray exp3('c', {0}, {8}, nd4j::DataType::DOUBLE); NDArray exp3('c', {0}, std::vector<double>{8}, nd4j::DataType::DOUBLE);
NDArray exp4('c', {0}, {1}, nd4j::DataType::BOOL); NDArray exp4('c', {0}, std::vector<double>{1}, nd4j::DataType::BOOL);
NDArray scalar = x1.reduceNumber(reduce::Sum); NDArray scalar = x1.reduceNumber(reduce::Sum);
@ -866,7 +866,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberBool_test1) {
NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE); NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE);
NDArray x4('c', {2,2}, {-2, -1, 0, 1}, nd4j::DataType::BOOL); NDArray x4('c', {2,2}, {-2, -1, 0, 1}, nd4j::DataType::BOOL);
NDArray exp1('c', {0}, {1}, nd4j::DataType::BOOL); NDArray exp1('c', {0}, std::vector<double>{1}, nd4j::DataType::BOOL);
NDArray scalar = x1.reduceNumber(reduce::IsFinite); NDArray scalar = x1.reduceNumber(reduce::IsFinite);
ASSERT_EQ(scalar, exp1); ASSERT_EQ(scalar, exp1);
@ -899,10 +899,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberLong_test1) {
NDArray x3('c', {2,2}, {0.5, -1.5, 0, 3.5}, nd4j::DataType::DOUBLE); NDArray x3('c', {2,2}, {0.5, -1.5, 0, 3.5}, nd4j::DataType::DOUBLE);
NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL);
NDArray exp1('c', {0}, {3}, nd4j::DataType::INT64); NDArray exp1('c', {0}, std::vector<double>{3}, nd4j::DataType::INT64);
NDArray exp2('c', {0}, {4}, nd4j::DataType::INT64); NDArray exp2('c', {0}, std::vector<double>{4}, nd4j::DataType::INT64);
NDArray exp3('c', {0}, {3}, nd4j::DataType::INT64); NDArray exp3('c', {0}, std::vector<double>{3}, nd4j::DataType::INT64);
NDArray exp4('c', {0}, {2}, nd4j::DataType::INT64); NDArray exp4('c', {0}, std::vector<double>{2}, nd4j::DataType::INT64);
NDArray scalar = x1.reduceNumber(reduce::CountNonZero); NDArray scalar = x1.reduceNumber(reduce::CountNonZero);
ASSERT_EQ(scalar, exp1); ASSERT_EQ(scalar, exp1);
@ -934,9 +934,9 @@ TEST_F(MultiDataTypeTests, ndarray_indexReduceNumber_test1) {
NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, nd4j::DataType::HALF); NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, nd4j::DataType::HALF);
NDArray x3('c', {2,2}, {0, -1, 0, 1}, nd4j::DataType::BOOL); NDArray x3('c', {2,2}, {0, -1, 0, 1}, nd4j::DataType::BOOL);
NDArray exp1('c', {0}, {3}, nd4j::DataType::INT64); NDArray exp1('c', {0}, std::vector<double>{3}, nd4j::DataType::INT64);
NDArray exp2('c', {0}, {2}, nd4j::DataType::INT64); NDArray exp2('c', {0}, std::vector<double>{2}, nd4j::DataType::INT64);
NDArray exp3('c', {0}, {1}, nd4j::DataType::INT64); NDArray exp3('c', {0}, std::vector<double>{1}, nd4j::DataType::INT64);
NDArray scalar = x1.indexReduceNumber(nd4j::indexreduce::IndexAbsoluteMax); NDArray scalar = x1.indexReduceNumber(nd4j::indexreduce::IndexAbsoluteMax);
ASSERT_EQ(scalar, exp1); ASSERT_EQ(scalar, exp1);
@ -1238,15 +1238,15 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) {
NDArray x7('c', {2}, {1, 2}, nd4j::DataType::INT64); NDArray x7('c', {2}, {1, 2}, nd4j::DataType::INT64);
NDArray x8('c', {2,2}, nd4j::DataType::BOOL); NDArray x8('c', {2,2}, nd4j::DataType::BOOL);
NDArray x13('c', {0}, {3}, nd4j::DataType::INT64); NDArray x13('c', {0}, std::vector<double>{3}, nd4j::DataType::INT64);
NDArray x14('c', {0}, {1.5}, nd4j::DataType::DOUBLE); NDArray x14('c', {0}, std::vector<double>{1.5}, nd4j::DataType::DOUBLE);
NDArray x15(nd4j::DataType::DOUBLE); NDArray x15(nd4j::DataType::DOUBLE);
NDArray x16('c', {2,2}, nd4j::DataType::DOUBLE); NDArray x16('c', {2,2}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {2,2}, {11, 22, 31, 42}, nd4j::DataType::HALF); NDArray exp1('c', {2,2}, {11, 22, 31, 42}, nd4j::DataType::HALF);
NDArray exp2('c', {2,2}, {11, 22, 31, 42}, nd4j::DataType::INT32); NDArray exp2('c', {2,2}, {11, 22, 31, 42}, nd4j::DataType::INT32);
NDArray exp3('c', {2,2}, {1, 1, 1, 1}, nd4j::DataType::BOOL); NDArray exp3('c', {2,2}, {1, 1, 1, 1}, nd4j::DataType::BOOL);
NDArray exp4('c', {0}, {4.5}, nd4j::DataType::DOUBLE); NDArray exp4('c', {0}, std::vector<double>{4.5}, nd4j::DataType::DOUBLE);
NDArray exp5('c', {2,2}, {11.5, 21.5, 31.5, 41.5}, nd4j::DataType::DOUBLE); NDArray exp5('c', {2,2}, {11.5, 21.5, 31.5, 41.5}, nd4j::DataType::DOUBLE);
x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2, x3); x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2, x3);
@ -1289,13 +1289,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test2) {
NDArray x1('c', {2,2}, {10, 20, 30, 40}, nd4j::DataType::HALF); NDArray x1('c', {2,2}, {10, 20, 30, 40}, nd4j::DataType::HALF);
NDArray x2('c', {2}, {10, 40}, nd4j::DataType::HALF); NDArray x2('c', {2}, {10, 40}, nd4j::DataType::HALF);
NDArray x3('c', {2,2}, nd4j::DataType::BOOL); NDArray x3('c', {2,2}, nd4j::DataType::BOOL);
NDArray x4('c', {0}, {10}, nd4j::DataType::HALF); NDArray x4('c', {0}, std::vector<double>{10}, nd4j::DataType::HALF);
NDArray x5('c', {0}, {20}, nd4j::DataType::HALF); NDArray x5('c', {0}, std::vector<double>{20}, nd4j::DataType::HALF);
NDArray x6(nd4j::DataType::BOOL); NDArray x6(nd4j::DataType::BOOL);
NDArray exp1('c', {2,2}, {1, 0, 0, 1}, nd4j::DataType::BOOL); NDArray exp1('c', {2,2}, {1, 0, 0, 1}, nd4j::DataType::BOOL);
NDArray exp2('c', {2,2}, {1, 0, 0, 0}, nd4j::DataType::BOOL); NDArray exp2('c', {2,2}, {1, 0, 0, 0}, nd4j::DataType::BOOL);
NDArray exp3('c', {0}, {0}, nd4j::DataType::BOOL); NDArray exp3('c', {0}, std::vector<double>{0}, nd4j::DataType::BOOL);
x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x2, x3); x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x2, x3);
ASSERT_EQ(x3, exp1); ASSERT_EQ(x3, exp1);
@ -1459,16 +1459,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedLambda_test1) {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) { TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) {
NDArray x1('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::DOUBLE); NDArray x1('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::DOUBLE);
NDArray x2('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT64); NDArray x2('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT64);
NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE);
NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, nd4j::DataType::BOOL); NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, nd4j::DataType::BOOL);
NDArray x7('c', {2,2}, nd4j::DataType::BOOL); NDArray x7('c', {2,2}, nd4j::DataType::BOOL);
NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::FLOAT32); NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::FLOAT32);
NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::DOUBLE); NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::DOUBLE);
NDArray other3('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::INT64); NDArray other3('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::INT64);
NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, nd4j::DataType::BOOL); NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, nd4j::DataType::BOOL);
auto func1 = [](float elem1, float elem2) { return elem1 + elem2; }; auto func1 = [](float elem1, float elem2) { return elem1 + elem2; };
@ -1478,10 +1478,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) {
auto func5 = [](float elem1, int elem2) { return elem1 - elem2; }; auto func5 = [](float elem1, int elem2) { return elem1 - elem2; };
NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, nd4j::DataType::DOUBLE); NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, nd4j::DataType::DOUBLE);
NDArray exp2('c', {2,2}, {0, 0, 0, 0}, nd4j::DataType::INT64); NDArray exp2('c', {2,2}, {0., 0, 0, 0}, nd4j::DataType::INT64);
NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, nd4j::DataType::FLOAT32); NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, nd4j::DataType::FLOAT32);
NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); NDArray exp5('c', {2,2}, {0., 1, 0, 1}, nd4j::DataType::BOOL);
x1.applyPairwiseLambda<double>(other2, func1, x4); x1.applyPairwiseLambda<double>(other2, func1, x4);
ASSERT_EQ(x4, exp1); ASSERT_EQ(x4, exp1);
@ -1505,16 +1505,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) { TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) {
NDArray x1('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::DOUBLE); NDArray x1('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::DOUBLE);
NDArray x2('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT64); NDArray x2('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT64);
NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE);
NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32); NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, nd4j::DataType::BOOL); NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, nd4j::DataType::BOOL);
NDArray x7('c', {2,2}, nd4j::DataType::BOOL); NDArray x7('c', {2,2}, nd4j::DataType::BOOL);
NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::FLOAT32); NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::FLOAT32);
NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::DOUBLE); NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::DOUBLE);
NDArray other3('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::INT64); NDArray other3('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::INT64);
NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, nd4j::DataType::BOOL); NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, nd4j::DataType::BOOL);
auto func1 = [](Nd4jLong idx, float elem1, float elem2) { return elem1 + elem2 + idx; }; auto func1 = [](Nd4jLong idx, float elem1, float elem2) { return elem1 + elem2 + idx; };
@ -1524,10 +1524,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) {
auto func5 = [](Nd4jLong idx, float elem1, int elem2) { return elem1 - elem2 + idx; }; auto func5 = [](Nd4jLong idx, float elem1, int elem2) { return elem1 - elem2 + idx; };
NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, nd4j::DataType::DOUBLE); NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, nd4j::DataType::DOUBLE);
NDArray exp2('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT64); NDArray exp2('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT64);
NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, nd4j::DataType::FLOAT32); NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, nd4j::DataType::FLOAT32);
NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2,2}, {0, 1, 1, 1}, nd4j::DataType::BOOL); NDArray exp5('c', {2,2}, {0., 1, 1, 1}, nd4j::DataType::BOOL);
x1.applyIndexedPairwiseLambda<double>(other2, func1, x4); x1.applyIndexedPairwiseLambda<double>(other2, func1, x4);
ASSERT_EQ(x4, exp1); ASSERT_EQ(x4, exp1);
@ -1551,25 +1551,25 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) { TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) {
NDArray x1('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::DOUBLE); NDArray x1('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::DOUBLE);
NDArray x2('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::DOUBLE); NDArray x2('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::DOUBLE);
NDArray x3('c', {2,2}, {0, -1.5, -2.5, -3.5}, nd4j::DataType::DOUBLE); NDArray x3('c', {2,2}, {0, -1.5, -2.5, -3.5}, nd4j::DataType::DOUBLE);
NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE);
NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT32); NDArray x5('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT32);
NDArray x6('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::INT32); NDArray x6('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::INT32);
NDArray x7('c', {2,2}, {0, 10, 20, 30}, nd4j::DataType::INT32); NDArray x7('c', {2,2}, {0., 10, 20, 30}, nd4j::DataType::INT32);
NDArray x8('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL); NDArray x8('c', {2,2}, {0., 1, 0, 1}, nd4j::DataType::BOOL);
NDArray x9('c', {2,2}, {1, 1, 0, 1}, nd4j::DataType::BOOL); NDArray x9('c', {2,2}, {1., 1, 0, 1}, nd4j::DataType::BOOL);
NDArray x10('c', {2,2}, {0, 0, 0, 0}, nd4j::DataType::BOOL); NDArray x10('c', {2,2}, {0., 0, 0, 0}, nd4j::DataType::BOOL);
auto func1 = [](double elem1, float elem2, int elem3) { return elem1 + elem2 + elem3; }; auto func1 = [](double elem1, float elem2, int elem3) { return elem1 + elem2 + elem3; };
auto func2 = [](float elem1, float elem2, float elem3) { return elem1 + elem2 + elem3; }; auto func2 = [](float elem1, float elem2, float elem3) { return elem1 + elem2 + elem3; };
auto func3 = [](int elem1, int elem2, int elem3) { return elem1 + elem2 + elem3; }; auto func3 = [](int elem1, int elem2, int elem3) { return elem1 + elem2 + elem3; };
auto func4 = [](bool elem1, bool elem2, bool elem3) { return elem1 + elem2 + elem3; }; auto func4 = [](bool elem1, bool elem2, bool elem3) { return elem1 + elem2 + elem3; };
NDArray exp('c', {2,2}, {1, 1, 0, 1}, nd4j::DataType::BOOL); NDArray exp('c', {2,2}, {1., 1, 0, 1}, nd4j::DataType::BOOL);
x1.applyTriplewiseLambda<double>(x2, x3, func1, x4); x1.applyTriplewiseLambda<double>(x2, x3, func1, x4);
ASSERT_EQ(x4, x2); ASSERT_EQ(x4, x2);
@ -1590,7 +1590,7 @@ TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) {
TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) {
NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE); NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {}, {5}, nd4j::DataType::INT64); NDArray exp1('c', {}, std::vector<double>{5}, nd4j::DataType::INT64);
NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64);
NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64);
@ -1608,10 +1608,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) {
TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) { TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) {
NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE); NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE);
NDArray scalar('c', {}, {5}, nd4j::DataType::INT64); NDArray scalar('c', {}, std::vector<double>{5}, nd4j::DataType::INT64);
NDArray vec1('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray vec1('c', {2}, {2,2}, nd4j::DataType::INT64);
NDArray vec2('c', {3}, {1,1,1}, nd4j::DataType::INT64); NDArray vec2('c', {3}, {1,1,1}, nd4j::DataType::INT64);
NDArray exp1('c', {}, {5}, nd4j::DataType::INT64); NDArray exp1('c', {}, std::vector<double>{5}, nd4j::DataType::INT64);
NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64);
NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64);
@ -1632,8 +1632,8 @@ TEST_F(MultiDataTypeTests, applyReduce3_test1) {
NDArray x2('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32); NDArray x2('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32);
NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32); NDArray exp1('c', {}, std::vector<double>{-30}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE); NDArray exp2('c', {}, std::vector<double>{15}, nd4j::DataType::DOUBLE);
auto result = x1.applyReduce3(reduce3::Dot, x2); auto result = x1.applyReduce3(reduce3::Dot, x2);
ASSERT_EQ(result, exp1); ASSERT_EQ(result, exp1);
@ -1654,8 +1654,8 @@ TEST_F(MultiDataTypeTests, applyReduce3_test2) {
NDArray x7('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x7('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
NDArray x8('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); NDArray x8('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32); NDArray exp1('c', {}, std::vector<double>{-30}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE); NDArray exp2('c', {}, std::vector<double>{15}, nd4j::DataType::DOUBLE);
NDArray exp3('c', {3}, {-18,-20,-18}, nd4j::DataType::FLOAT32); NDArray exp3('c', {3}, {-18,-20,-18}, nd4j::DataType::FLOAT32);
NDArray exp4('c', {2}, {-28,-28}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2}, {-28,-28}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {3}, {7.5,10.5,13.5}, nd4j::DataType::DOUBLE); NDArray exp5('c', {3}, {7.5,10.5,13.5}, nd4j::DataType::DOUBLE);

View File

@ -184,7 +184,7 @@ TEST_F(NDArrayConstructorsTests, test_linspace_1) {
TEST_F(NDArrayConstructorsTests, test_constructor_10) { TEST_F(NDArrayConstructorsTests, test_constructor_10) {
NDArray scalar1(nd4j::DataType::DOUBLE); // scalar1 = 0 NDArray scalar1(nd4j::DataType::DOUBLE); // scalar1 = 0
NDArray scalar2('c', {}, {0}); NDArray scalar2('c', {}, std::vector<double>{0});
ASSERT_TRUE(scalar1.isActualOnDeviceSide()); ASSERT_TRUE(scalar1.isActualOnDeviceSide());
ASSERT_TRUE(!scalar1.isActualOnHostSide()); ASSERT_TRUE(!scalar1.isActualOnHostSide());

View File

@ -1226,8 +1226,8 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_3) {
NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {}, {-204}, nd4j::DataType::FLOAT32); NDArray exp1('c', {}, std::vector<double>{-204}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {}, {31.5}, nd4j::DataType::DOUBLE); NDArray exp2('c', {}, std::vector<double>{31.5}, nd4j::DataType::DOUBLE);
auto z = x1.applyReduce3(reduce3::Dot, x2); auto z = x1.applyReduce3(reduce3::Dot, x2);
@ -1260,7 +1260,7 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) {
NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f, NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f,
-4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f}, -4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f},
nd4j::DataType::FLOAT32); nd4j::DataType::FLOAT32);
NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE); NDArray exp3('c', {1,1}, std::vector<double>{31.5}, nd4j::DataType::DOUBLE);
NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE);
auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2});
@ -1292,15 +1292,15 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test1) {
NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE); NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE);
NDArray scalar('c', {}, {100}, nd4j::DataType::INT64); NDArray scalar('c', {}, std::vector<double>{100}, nd4j::DataType::INT64);
NDArray vec1('c', {2}, {100,100}, nd4j::DataType::INT64); NDArray vec1('c', {2}, {100,100}, nd4j::DataType::INT64);
NDArray vec2('c', {3}, {100,100,100}, nd4j::DataType::INT64); NDArray vec2('c', {3}, {100,100,100}, nd4j::DataType::INT64);
NDArray exp1('c', {}, {1}, nd4j::DataType::INT64); NDArray exp1('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64);
NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64);
NDArray exp4('c', {}, {2}, nd4j::DataType::INT64); NDArray exp4('c', {}, std::vector<double>{2}, nd4j::DataType::INT64);
NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64);
NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64); NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64);
@ -1331,11 +1331,11 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test2) {
NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE); NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {}, {1}, nd4j::DataType::INT64); NDArray exp1('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64);
NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64);
NDArray exp4('c', {}, {2}, nd4j::DataType::INT64); NDArray exp4('c', {}, std::vector<double>{2}, nd4j::DataType::INT64);
NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64); NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64);
NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64); NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64);
@ -1365,13 +1365,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) {
NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::INT32); NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::INT32);
NDArray z1('c', {}, {100}, nd4j::DataType::DOUBLE); NDArray z1('c', {}, std::vector<double>{100}, nd4j::DataType::DOUBLE);
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::DOUBLE); NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::DOUBLE);
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32); NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE); NDArray exp1('c', {}, std::vector<double>{2.166667}, nd4j::DataType::DOUBLE);
NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, nd4j::DataType::FLOAT32);
NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE); NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32); NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32);
@ -1403,7 +1403,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test2) {
NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::DOUBLE); NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE); NDArray exp1('c', {}, std::vector<double>{2.166667}, nd4j::DataType::DOUBLE);
NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::DOUBLE); NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::DOUBLE);
NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE); NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::DOUBLE);
@ -1477,13 +1477,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, nd4j::DataType::FLOAT32); NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, nd4j::DataType::FLOAT32);
NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32); NDArray z1('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::FLOAT32); NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::FLOAT32);
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32); NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {}, {26.5f}, nd4j::DataType::FLOAT32); NDArray exp1('c', {}, std::vector<double>{26.5f}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, nd4j::DataType::FLOAT32);
NDArray exp3('c', {3}, {19.f,4.f,3.5f}, nd4j::DataType::FLOAT32); NDArray exp3('c', {3}, {19.f,4.f,3.5f}, nd4j::DataType::FLOAT32);
NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32); NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32);
@ -1515,7 +1515,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test2) {
NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::INT64); NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::INT64);
NDArray exp1('c', {}, {26}, nd4j::DataType::INT64); NDArray exp1('c', {}, std::vector<double>{26}, nd4j::DataType::INT64);
NDArray exp2('c', {2,2}, {9,12,3,2}, nd4j::DataType::INT64); NDArray exp2('c', {2,2}, {9,12,3,2}, nd4j::DataType::INT64);
NDArray exp3('c', {3}, {18,4,4}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {18,4,4}, nd4j::DataType::INT64);
NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, nd4j::DataType::INT64); NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, nd4j::DataType::INT64);
@ -1547,13 +1547,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) {
NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE); NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE);
NDArray z1('c', {}, {true}, nd4j::DataType::BOOL); NDArray z1('c', {}, std::vector<double>{true}, nd4j::DataType::BOOL);
NDArray z2('c', {2,2}, {true,true,true,true}, nd4j::DataType::BOOL); NDArray z2('c', {2,2}, {true,true,true,true}, nd4j::DataType::BOOL);
NDArray z3('c', {3}, {true,true,true}, nd4j::DataType::BOOL); NDArray z3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
NDArray z4('c', {3,2}, {true,true,true,true,true,true}, nd4j::DataType::BOOL); NDArray z4('c', {3,2}, {true,true,true,true,true,true}, nd4j::DataType::BOOL);
NDArray z5('c', {2}, {true,true}, nd4j::DataType::BOOL); NDArray z5('c', {2}, {true,true}, nd4j::DataType::BOOL);
NDArray exp1('c', {}, {true}, nd4j::DataType::BOOL); NDArray exp1('c', {}, std::vector<double>{true}, nd4j::DataType::BOOL);
NDArray exp2('c', {2,2}, {true,true,false,true}, nd4j::DataType::BOOL); NDArray exp2('c', {2,2}, {true,true,false,true}, nd4j::DataType::BOOL);
NDArray exp3('c', {3}, {true,true,true}, nd4j::DataType::BOOL); NDArray exp3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL); NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL);
@ -1585,7 +1585,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) {
NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::INT32); NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::INT32);
NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL); NDArray exp1('c', {}, std::vector<double>{1}, nd4j::DataType::BOOL);
NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL); NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL);
NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL);
NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, nd4j::DataType::BOOL); NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, nd4j::DataType::BOOL);
@ -1617,13 +1617,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) {
NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, nd4j::DataType::FLOAT32); NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, nd4j::DataType::FLOAT32);
NDArray z1('c', {}, {100}, nd4j::DataType::INT64); NDArray z1('c', {}, std::vector<double>{100}, nd4j::DataType::INT64);
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64);
NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::INT64); NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::INT64);
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::INT64); NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::INT64);
NDArray z5('c', {2}, {100,100}, nd4j::DataType::INT64); NDArray z5('c', {2}, {100,100}, nd4j::DataType::INT64);
NDArray exp1('c', {}, {2}, nd4j::DataType::INT64); NDArray exp1('c', {}, std::vector<double>{2}, nd4j::DataType::INT64);
NDArray exp2('c', {2,2}, {0,1,0,1}, nd4j::DataType::INT64); NDArray exp2('c', {2,2}, {0,1,0,1}, nd4j::DataType::INT64);
NDArray exp3('c', {3}, {1,1,0}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,0}, nd4j::DataType::INT64);
NDArray exp4('c', {3,2}, {0,1,0,1,0,0}, nd4j::DataType::INT64); NDArray exp4('c', {3,2}, {0,1,0,1,0,0}, nd4j::DataType::INT64);
@ -1655,7 +1655,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test2) {
NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::INT32); NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::INT32);
NDArray exp1('c', {}, {4}, nd4j::DataType::INT64); NDArray exp1('c', {}, std::vector<double>{4}, nd4j::DataType::INT64);
NDArray exp2('c', {2,2}, {1,1,0,2}, nd4j::DataType::INT64); NDArray exp2('c', {2,2}, {1,1,0,2}, nd4j::DataType::INT64);
NDArray exp3('c', {3}, {2,2,0}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {2,2,0}, nd4j::DataType::INT64);
NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, nd4j::DataType::INT64); NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, nd4j::DataType::INT64);

View File

@ -692,7 +692,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) {
TEST_F(ParityOpsTests, Test_Scatter_Add_1) { TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
NDArray idc('c', {1}, {0}, nd4j::DataType::INT64); NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {1, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 2}, {1, 1});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {2, 3, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {2, 3, 3, 4});
@ -710,7 +710,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
TEST_F(ParityOpsTests, Test_Scatter_Add_2) { TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4}); auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4});
NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64); NDArray idc('c', {1, 4}, {0., 1, 2, 3}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1});
auto exp = NDArrayFactory::create<float>('c', {1, 4}, {2, 3, 4, 5}); auto exp = NDArrayFactory::create<float>('c', {1, 4}, {2, 3, 4, 5});
@ -727,7 +727,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
TEST_F(ParityOpsTests, Test_Scatter_Add_3) { TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
NDArray idc('c', {1}, {0}, nd4j::DataType::INT64); NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2}, {1, 1, 1, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 2, 2}, {1, 1, 1, 1});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8});
@ -744,7 +744,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
TEST_F(ParityOpsTests, Test_Scatter_Add_4) { TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
NDArray idc('c', {1, 2}, {0, 0}, nd4j::DataType::INT64); NDArray idc('c', {1, 2}, std::vector<double>{0, 0}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8});
@ -761,7 +761,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
TEST_F(ParityOpsTests, Test_Scatter_Add_5) { TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); auto matrix = NDArrayFactory::create<float>('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
NDArray idc('c', {2, 2}, {1, 1, 0, 0}, nd4j::DataType::INT64); NDArray idc('c', {2, 2}, {1., 1, 0, 0}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); auto updates = NDArrayFactory::create<float>('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.});
@ -796,7 +796,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
TEST_F(ParityOpsTests, Test_Scatter_Add_7) { TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
auto matrix = NDArrayFactory::create<float>('c', {10, 3}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f,25.f,26.f,27.f,28.f,29.f,30.f}); auto matrix = NDArrayFactory::create<float>('c', {10, 3}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f,25.f,26.f,27.f,28.f,29.f,30.f});
NDArray idc('c', {}, {5}, nd4j::DataType::INT64); NDArray idc('c', {}, std::vector<double>{5}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {3}, {10.f, 20.f, 30.f}); auto updates = NDArrayFactory::create<float>('c', {3}, {10.f, 20.f, 30.f});
auto exp = NDArrayFactory::create<float>('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f}); auto exp = NDArrayFactory::create<float>('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f});
@ -845,7 +845,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_9) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatterMax_test1) { TEST_F(ParityOpsTests, scatterMax_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
NDArray idc('c', {1}, {0.}, nd4j::DataType::INT64); NDArray idc('c', {1}, std::vector<double>{0.}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10, 2, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10, 2, 3, 4});
@ -879,7 +879,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) {
TEST_F(ParityOpsTests, scatterMax_test3) { TEST_F(ParityOpsTests, scatterMax_test3) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
NDArray idc('c', {1}, {0}, nd4j::DataType::INT64); NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2}, {10, 1, 30, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 2, 2}, {10, 1, 30, 1});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8});
@ -896,7 +896,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) {
TEST_F(ParityOpsTests, scatterMax_test4) { TEST_F(ParityOpsTests, scatterMax_test4) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
NDArray idc('c', {1,2}, {0,0}, nd4j::DataType::INT32); NDArray idc('c', {1,2}, std::vector<double>{0.,0}, nd4j::DataType::INT32);
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8});
@ -948,7 +948,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) {
TEST_F(ParityOpsTests, scatterMin_test1) { TEST_F(ParityOpsTests, scatterMin_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
NDArray idc('c', {1}, {0}, nd4j::DataType::INT32); NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT32);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {-1, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 2}, {-1, 1});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-1, 1, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-1, 1, 3, 4});
@ -982,7 +982,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) {
TEST_F(ParityOpsTests, scatterMin_test3) { TEST_F(ParityOpsTests, scatterMin_test3) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
NDArray idc('c', {1}, {0}, nd4j::DataType::INT32); NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT32);
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2}, {10, 1, 30, 2}); auto updates = NDArrayFactory::create<float>('c', {1, 2, 2}, {10, 1, 30, 2});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8});
@ -999,7 +999,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) {
TEST_F(ParityOpsTests, scatterMin_test4) { TEST_F(ParityOpsTests, scatterMin_test4) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
NDArray idc('c', {1,2}, {0,0}, nd4j::DataType::INT32); NDArray idc('c', {1,2}, std::vector<double>{0.,0}, nd4j::DataType::INT32);
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8}); auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8});

View File

@ -1007,9 +1007,9 @@ TEST_F(RNGTests, test_uniform_119) {
TEST_F(RNGTests, test_multinomial_1) { TEST_F(RNGTests, test_multinomial_1) {
NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32); NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
NDArray expected('f', { 3, 3 }, { 0, 1, 2, 2, 0, 0, 1, 2, 1 }, nd4j::DataType::INT64); NDArray expected('f', { 3, 3 }, { 0., 1, 2, 2, 0, 0, 1, 2, 1 }, nd4j::DataType::INT64);
NDArray output('f', { 3, 3 }, nd4j::DataType::INT64); NDArray output('f', { 3, 3 }, nd4j::DataType::INT64);
NDArray samples('f', { 1 }, { 3 }, nd4j::DataType::INT32); NDArray samples('f', { 1 }, std::vector<double>({3}), nd4j::DataType::INT32);
nd4j::ops::random_multinomial op; nd4j::ops::random_multinomial op;
RandomGenerator rng(1234, 1234); RandomGenerator rng(1234, 1234);
@ -1018,7 +1018,7 @@ TEST_F(RNGTests, test_multinomial_1) {
ASSERT_TRUE(expected.equalsTo(output)); ASSERT_TRUE(expected.equalsTo(output));
NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32); NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64); NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64);
auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 }); auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
auto outputZ = result->at(0); auto outputZ = result->at(0);
@ -1031,7 +1031,7 @@ TEST_F(RNGTests, test_multinomial_1) {
TEST_F(RNGTests, test_multinomial_2) { TEST_F(RNGTests, test_multinomial_2) {
NDArray samples('c', { 1 }, { 20 }, nd4j::DataType::INT32); NDArray samples('c', { 1 }, std::vector<double>{ 20 }, nd4j::DataType::INT32);
NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32); NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32);
NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, nd4j::DataType::INT64); NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, nd4j::DataType::INT64);
NDArray output('c', { 3, 20 }, nd4j::DataType::INT64); NDArray output('c', { 3, 20 }, nd4j::DataType::INT64);
@ -1057,10 +1057,11 @@ TEST_F(RNGTests, test_multinomial_3) {
NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32); NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
NDArray expected('c', { 4, 5 }, nd4j::DataType::INT64); NDArray expected('c', { 4, 5 }, nd4j::DataType::INT64);
NDArray output('c', { 4, 5 }, nd4j::DataType::INT64); NDArray output('c', { 4, 5 }, nd4j::DataType::INT64);
NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32); NDArray samples('c', { 1 }, std::vector<double>{ 5 }, nd4j::DataType::INT32);
RandomGenerator rng(1234, 1234); RandomGenerator rng(1234, 1234);
nd4j::ops::random_multinomial op; nd4j::ops::random_multinomial op;
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false));
rng.setStates(1234, 1234); rng.setStates(1234, 1234);
@ -1074,7 +1075,7 @@ TEST_F(RNGTests, test_multinomial_4) {
NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32); NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
NDArray expected('c', { 5, 4 }, nd4j::DataType::INT64); NDArray expected('c', { 5, 4 }, nd4j::DataType::INT64);
NDArray output('c', { 5, 4 }, nd4j::DataType::INT64); NDArray output('c', { 5, 4 }, nd4j::DataType::INT64);
NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32); NDArray samples('c', { 1 }, std::vector<double>{ 5 }, nd4j::DataType::INT32);
RandomGenerator rng(1234, 1234); RandomGenerator rng(1234, 1234);
nd4j::ops::random_multinomial op; nd4j::ops::random_multinomial op;
@ -1092,7 +1093,7 @@ TEST_F(RNGTests, test_multinomial_5) {
int ClassValue = 2; int ClassValue = 2;
int Samples = 100000; int Samples = 100000;
NDArray samples('c', { 1 }, { 1.*Samples }, nd4j::DataType::INT32); NDArray samples('c', { 1 }, std::vector<double>{ 1.*Samples }, nd4j::DataType::INT32);
NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, nd4j::DataType::FLOAT32); NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, nd4j::DataType::FLOAT32);
@ -1140,7 +1141,7 @@ TEST_F(RNGTests, test_multinomial_6) {
int ClassValue = 5; int ClassValue = 5;
int Samples = 100000; int Samples = 100000;
NDArray samples('c', { 1 }, { 1. * Samples }, nd4j::DataType::INT32); NDArray samples('c', { 1 }, std::vector<double>{ 1. * Samples }, nd4j::DataType::INT32);
nd4j::ops::random_multinomial op; nd4j::ops::random_multinomial op;
NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, nd4j::DataType::DOUBLE); NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, nd4j::DataType::DOUBLE);
@ -1152,7 +1153,7 @@ TEST_F(RNGTests, test_multinomial_6) {
auto outputR = resultR->at(0); auto outputR = resultR->at(0);
ASSERT_EQ(Status::OK(), resultR->status()); ASSERT_EQ(Status::OK(), resultR->status());
NDArray countsR('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE); NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
for (int i = 0; i < outputR->lengthOf(); i++) { for (int i = 0; i < outputR->lengthOf(); i++) {
auto value = outputR->e<Nd4jLong>(i); auto value = outputR->e<Nd4jLong>(i);
@ -1182,7 +1183,7 @@ TEST_F(RNGTests, test_multinomial_6) {
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE); NDArray counts('c', { ClassValue }, { 0., 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
for (int i = 0; i < output.lengthOf(); i++) { for (int i = 0; i < output.lengthOf(); i++) {
auto value = output.e<Nd4jLong>(i); auto value = output.e<Nd4jLong>(i);

View File

@ -1,5 +1,6 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019-2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -16,6 +17,7 @@
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
// //
@ -30,7 +32,7 @@ class StringTests : public testing::Test {
public: public:
}; };
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_1) { TEST_F(StringTests, Basic_Test_1) {
std::string f("alpha"); std::string f("alpha");
auto array = NDArrayFactory::string(f); auto array = NDArrayFactory::string(f);
@ -43,7 +45,7 @@ TEST_F(StringTests, Basic_Test_1) {
ASSERT_EQ(f, z); ASSERT_EQ(f, z);
} }
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_2) { TEST_F(StringTests, Basic_Test_2) {
std::string f("alpha"); std::string f("alpha");
auto array = NDArrayFactory::string(f.c_str()); auto array = NDArrayFactory::string(f.c_str());
@ -56,23 +58,213 @@ TEST_F(StringTests, Basic_Test_2) {
ASSERT_EQ(f, z); ASSERT_EQ(f, z);
} }
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_3) { TEST_F(StringTests, Basic_Test_3) {
auto array = NDArrayFactory::string('c', {3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"});
auto array = NDArrayFactory::string({3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"});
ASSERT_EQ(6, array.lengthOf()); ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf()); ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array"); array.printIndexedBuffer("String array");
} }
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_4) {
NDArray array( { 3, 2 }, std::vector<const char32_t*>{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_5) {
NDArray array( { 3, 2 }, std::vector<const char16_t*>{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_6) {
NDArray array( { 3, 2 }, std::vector<const char*>{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_7) {
NDArray array( { 3, 2 }, std::vector<std::u32string>{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_8) {
NDArray array( { 3, 2 }, std::vector<std::u16string>{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_9) {
NDArray array( { 3, 2 }, std::vector<std::string>{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_10) {
NDArray array(std::u32string(U"gamma€한"));
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_11) {
NDArray array(U"gamma€한");
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_12) {
NDArray array(std::u16string(u"gamma€한"));
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_13) {
NDArray array(u"gamma€한");
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_14) {
NDArray array(std::string("gamma€한"));
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_15) {
NDArray array("gamma€한");
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_16) {
auto array = NDArrayFactory::string( { 3, 2 }, std::vector<std::string>{ "alpha", "beta", "gamma", "phi", "theta", "omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_17) {
auto array = NDArrayFactory::string({ 3, 2 }, std::vector<const char*>{ "alpha", "beta", "gamma", "phi", "theta", "omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_18) {
auto array = NDArrayFactory::string({ 3, 2 }, std::vector<std::u16string>{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_19) {
auto array = NDArrayFactory::string( { 3, 2 }, std::vector<const char16_t*>{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_20) {
auto array = NDArrayFactory::string( { 3, 2 }, std::vector<std::u32string>{ U"alpha", U"beta", U"gamma", U"phi", U"theta", U"omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_21) {
auto array = NDArrayFactory::string( { 3, 2 }, std::vector<const char32_t*>{ U"alpha", U"òèçùà12345¤z", U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї", U"phi", U"theta", U"omega" });
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_22) {
std::u16string f(u"ß水𝄋ÿ€한𐍈®кею90ощъ]ї");
auto array = NDArrayFactory::string(f.c_str());
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto z = array.e<std::u16string>(0);
ASSERT_EQ(f, z);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_23) {
std::u32string f(U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї");
auto array = NDArrayFactory::string(f.c_str());
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto z = array.e<std::u32string>(0);
ASSERT_EQ(f, z);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Export_Test_1) { TEST_F(StringTests, Export_Test_1) {
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"});
auto vector = array.asByteVector(); auto vector = array.asByteVector();
} }
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_dup_1) { TEST_F(StringTests, Basic_dup_1) {
std::string f("alpha"); std::string f("alpha");
auto array = NDArrayFactory::string(f); auto array = NDArrayFactory::string(f);
@ -91,20 +283,20 @@ TEST_F(StringTests, Basic_dup_1) {
delete dup; delete dup;
} }
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, byte_length_test_1) { TEST_F(StringTests, byte_length_test_1) {
std::string f("alpha"); std::string f("alpha");
auto array = NDArrayFactory::string(f); auto array = NDArrayFactory::string(f);
ASSERT_EQ(f.length(), StringUtils::byteLength(array)); ASSERT_EQ(f.length(), StringUtils::byteLength(array));
} }
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, byte_length_test_2) { TEST_F(StringTests, byte_length_test_2) {
auto array = NDArrayFactory::string('c', {2}, {"alpha", "beta"}); auto array = NDArrayFactory::string( {2}, {"alpha", "beta"});
ASSERT_EQ(9, StringUtils::byteLength(array)); ASSERT_EQ(9, StringUtils::byteLength(array));
} }
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, test_split_1) { TEST_F(StringTests, test_split_1) {
auto split = StringUtils::split("alpha beta gamma", " "); auto split = StringUtils::split("alpha beta gamma", " ");
@ -113,3 +305,561 @@ TEST_F(StringTests, test_split_1) {
ASSERT_EQ(std::string("beta"), split[1]); ASSERT_EQ(std::string("beta"), split[1]);
ASSERT_EQ(std::string("gamma"), split[2]); ASSERT_EQ(std::string("gamma"), split[2]);
} }
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, test_unicode_utf8_utf16) {
std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::u16string utf16Res;
ASSERT_TRUE(StringUtils::u8StringToU16String(utf8, utf16Res));
ASSERT_EQ(utf16Res.size(), utf16Exp.size());
for (auto i = 0; i < utf16Exp.size(); i++) {
ASSERT_EQ(utf16Exp[i], utf16Res[i]);
}
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, test_unicode_utf8_utf32) {
std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::u32string utf32Res;
ASSERT_TRUE(StringUtils::u8StringToU32String(utf8, utf32Res));
ASSERT_EQ(utf32Res.size(), utf32Exp.size());
for (auto i = 0; i < utf32Exp.size(); i++) {
ASSERT_EQ(utf32Exp[i], utf32Res[i]);
}
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, test_unicode_utf16_utf8) {
std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::string utf8Res;
ASSERT_TRUE(StringUtils::u16StringToU8String(utf16, utf8Res));
ASSERT_EQ(utf8Res.size(), utf8Exp.size());
for (auto i = 0; i < utf8Exp.size(); i++) {
ASSERT_EQ(utf8Exp[i], utf8Res[i]);
}
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, test_unicode_utf32_utf8) {
std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~";
std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~";
std::string utf8Res;
ASSERT_TRUE(StringUtils::u32StringToU8String(utf32, utf8Res));
ASSERT_EQ(utf8Res.size(), utf8Exp.size());
for (auto i = 0; i < utf8Exp.size(); i++) {
ASSERT_EQ(utf8Exp[i], utf8Res[i]);
}
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, test_unicode_utf16_utf32) {
std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::u32string utf32Res;
ASSERT_TRUE(StringUtils::u16StringToU32String(utf16, utf32Res));
ASSERT_EQ(utf32Res.size(), utf32Exp.size());
for (auto i = 0; i < utf32Exp.size(); i++) {
ASSERT_EQ(utf32Exp[i], utf32Res[i]);
}
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, test_unicode_utf32_utf16) {
std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
std::u16string utf16Res;
ASSERT_TRUE(StringUtils::u32StringToU16String(utf32, utf16Res));
ASSERT_EQ(utf16Res.size(), utf16Exp.size());
for (auto i = 0; i < utf16Exp.size(); i++) {
ASSERT_EQ(utf16Exp[i], utf16Res[i]);
}
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, byte_length_test_Default) {
std::string f("alpha");
auto array = NDArrayFactory::string(f);
ASSERT_EQ(f.length(), StringUtils::byteLength(array));
std::u16string f16(u"alpha");
auto array16 = NDArrayFactory::string(f16);
ASSERT_EQ(sizeof(char16_t)*f16.length(), StringUtils::byteLength(array16));
std::u32string f32(U"alpha");
auto array32 = NDArrayFactory::string(f32);
ASSERT_EQ(sizeof(char32_t) * f32.length(), StringUtils::byteLength(array32));
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, byte_length_test_UTF16) {
std::string f(u8"alpha");
auto array = NDArrayFactory::string(f, nd4j::DataType::UTF16);
ASSERT_EQ(sizeof(char16_t) * f.length(), StringUtils::byteLength(array));
std::u16string f16(u"alpha");
auto array16 = NDArrayFactory::string(f16, nd4j::DataType::UTF16);
ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16));
std::u32string f32(U"alpha");
auto array32 = NDArrayFactory::string(f32, nd4j::DataType::UTF16);
ASSERT_EQ(sizeof(char16_t) * f32.length(), StringUtils::byteLength(array32));
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_UTF16toU8) {
std::u16string f16(u"alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(f16, nd4j::DataType::UTF8);
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto z = array.e<std::string>(0);
std::string f(u8"alpha水𝄋ÿ€한𐍈®кею");
ASSERT_EQ(f, z);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_UTF32toU8) {
std::u32string f32(U"alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(f32.c_str(), nd4j::DataType::UTF8);
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto z = array.e<std::string>(0);
std::string f(u8"alpha水𝄋ÿ€한𐍈®кею");
ASSERT_EQ(f, z);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_UTF16toU16) {
std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(f16, nd4j::DataType::UTF16);
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto z = array.e<std::u16string>(0);
ASSERT_EQ(z, f16);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_UTF32toU16) {
std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(f32, nd4j::DataType::UTF16);
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto z = array.e<std::u16string>(0);
std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею");
ASSERT_EQ(z, f16);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_UTF16toU32) {
std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(f16, nd4j::DataType::UTF32);
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto z = array.e<std::u32string>(0);
std::u32string fres(U"€alpha水𝄋ÿ€한𐍈®кею");
ASSERT_EQ(z, fres);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_UTF32toU32) {
std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(f32);
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto z = array.e<std::u32string>(0);
ASSERT_EQ(f32, z);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_UTF8toU32) {
std::string f(u8"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(f, nd4j::DataType::UTF32);
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею");
auto z = array.e<std::u32string>(0);
ASSERT_EQ(f32, z);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_StringVecU8toUTF16) {
auto array = NDArrayFactory::string({ 3, 2 }, { "alpha€", "beta", "gamma水", "phi", "theta", "omega水" }, nd4j::DataType::UTF16);
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_StringVecU8toUTF32) {
auto array = NDArrayFactory::string( { 3, 2 }, { "alpha€", "beta水", "gamma", "phi", "theta", "omega" }, nd4j::DataType::UTF32);
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Export_Test_U8toUTF16) {
auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, nd4j::DataType::UTF16);
auto vector = array.asByteVector();
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Export_Test_U8toUTF32) {
auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, nd4j::DataType::UTF32);
auto vector = array.asByteVector();
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_StringVecU16toUTF16) {
auto array = NDArrayFactory::string({ 3, 2 }, { u"alpha水", u"beta", u"gamma", u"phi", u"theta水", u"omega" }, nd4j::DataType::UTF16);
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_StringVecU16toUTF32) {
auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha水", u"beta", u"gamma水", u"phi", u"theta", u"omega" }, nd4j::DataType::UTF32);
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_StringVecU16toUTF8) {
auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha€", u"beta水", u"gamma", u"phi水", u"theta", u"omega" }, nd4j::DataType::UTF8);
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Export_Test_U16toUTF8) {
auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, nd4j::DataType::UTF8);
auto vector = array.asByteVector();
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Export_Test_U16toUTF16) {
auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, nd4j::DataType::UTF16);
auto vector = array.asByteVector();
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Export_Test_U16toUTF32) {
auto array = NDArrayFactory::string( { 3 }, { u"alpha水", u"beta", u"gamma水" }, nd4j::DataType::UTF32);
auto vector = array.asByteVector();
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_StringVecU32toUTF32) {
auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha€", U"beta水", U"gamma", U"phi", U"theta", U"omega水" }, nd4j::DataType::UTF32);
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_StringVecU32toUTF16) {
auto array = NDArrayFactory::string({ 3, 2 }, { U"alpha水", U"水beta", U"gamma", U"phi水", U"theta", U"omega" }, nd4j::DataType::UTF16);
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
printf("Array elements size: \n");
for (int e = 0; e < array.lengthOf(); e++) {
printf("Element %d size: %d\n", e, static_cast<int>(array.e<std::u16string>(e).size()));
}
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_Test_StringVecU32toUTF8) {
auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha水", U"beta", U"gamma水", U"phi", U"theta", U"omega" }, nd4j::DataType::UTF8);
ASSERT_EQ(6, array.lengthOf());
ASSERT_EQ(2, array.rankOf());
array.printIndexedBuffer("String array");
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Export_Test_U32toUTF32) {
auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma" }, nd4j::DataType::UTF32);
auto vector = array.asByteVector();
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Export_Test_U32toUTF16) {
auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta水", U"gamma水" }, nd4j::DataType::UTF16);
auto vector = array.asByteVector();
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Export_Test_U32toUTF8) {
auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma水" }, nd4j::DataType::UTF8);
auto vector = array.asByteVector();
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_dup_UTF16) {
std::u16string f(u"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(f);
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto dup = new NDArray(array.dup());
auto z0 = array.e<std::u16string>(0);
auto z1 = dup->e<std::u16string>(0);
ASSERT_EQ(f, z0);
ASSERT_EQ(f, z1);
delete dup;
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_dup_UTF32) {
std::u32string f(U"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(f);
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto dup = new NDArray(array.dup());
auto z0 = array.e<std::u32string>(0);
auto z1 = dup->e<std::u32string>(0);
ASSERT_EQ(f, z0);
ASSERT_EQ(f, z1);
delete dup;
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_cast_UTF32toUTF8) {
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(u32);
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto aCast = array.cast(nd4j::DataType::UTF8);
auto z0 = array.e<std::u32string>(0);
auto z1 = aCast.e<std::string>(0);
ASSERT_EQ(u32, z0);
ASSERT_EQ(u8, z1);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_cast_UTF32toUTF16) {
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(u32);
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto aCast = array.cast(nd4j::DataType::UTF16);
auto z0 = array.e<std::u32string>(0);
auto z1 = aCast.e<std::u16string>(0);
ASSERT_EQ(u32, z0);
ASSERT_EQ(u16, z1);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_cast_UTF32toUTF32) {
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(u32);
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto aCast = array.cast(nd4j::DataType::UTF32);
auto z0 = array.e<std::u32string>(0);
auto z1 = aCast.e<std::u32string>(0);
ASSERT_EQ(u32, z0);
ASSERT_EQ(u32, z1);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_cast_UTF16toUTF16) {
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(u16);
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto aCast = array.cast(nd4j::DataType::UTF16);
auto z0 = array.e<std::u16string>(0);
auto z1 = aCast.e<std::u16string>(0);
ASSERT_EQ(u16, z0);
ASSERT_EQ(u16, z1);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_cast_UTF16toUTF32) {
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(u16);
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto aCast = array.cast(nd4j::DataType::UTF32);
auto z0 = array.e<std::u16string>(0);
auto z1 = aCast.e<std::u32string>(0);
ASSERT_EQ(u32, z1);
ASSERT_EQ(u16, z0);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_cast_UTF16toUTF8) {
std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(u16);
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto aCast = array.cast(nd4j::DataType::UTF8);
auto z0 = array.e<std::u16string>(0);
auto z1 = aCast.e<std::string>(0);
ASSERT_EQ(u8, z1);
ASSERT_EQ(u16, z0);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_cast_UTF8toUTF8) {
std::string u8("€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(u8);
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto aCast = array.cast(nd4j::DataType::UTF8);
auto z0 = array.e<std::string>(0);
auto z1 = aCast.e<std::string>(0);
ASSERT_EQ(u8, z1);
ASSERT_EQ(u8, z0);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_cast_UTF8toUTF16) {
std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(u8);
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto aCast = array.cast(nd4j::DataType::UTF16);
auto z0 = array.e<std::string>(0);
auto z1 = aCast.e<std::u16string>(0);
ASSERT_EQ(u8, z0);
ASSERT_EQ(u16, z1);
}
/////////////////////////////////////////////////////////////////////////
TEST_F(StringTests, Basic_cast_UTF8toUTF32) {
std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
auto array = NDArrayFactory::string(u8);
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
ASSERT_EQ(1, array.lengthOf());
ASSERT_EQ(0, array.rankOf());
auto aCast = array.cast(nd4j::DataType::UTF32);
auto z0 = array.e<std::string>(0);
auto z1 = aCast.e<std::u32string>(0);
ASSERT_EQ(u8, z0);
ASSERT_EQ(u32, z1);
}