Oleh convert (#200)
* StringUtils for utf convertor raw implementation of all possible combinations, need to be add counter of bytes per symbol for any type and add api to call convertors and store data Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor more corrections to support convertors Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor some corrections and bug fixes, need review to discuss how to add multi-threading Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections to move to multi-threading, add one test need discussion data inputs/outputs array presentation, need discussion the way of multi-threading * StringUtils for utf convertor #8613 tests added some corrections to optimize build Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some corrections and code clean up Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 code clean up and optimize usage, need update ndarray factory before replace std usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some staff to integrate converters into NDArrayFactory, update tests and add some functionality Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor corrections and bug fix before discussion * StringUtils for utf convertor #8613 some fixes and tets * StringUtils for utf convertor #8613 some more staff to support different unicode Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fix linking bug * StringUtils for utf convertor #8613 corrected several tests as defaults for string ndarray changed * StringUtils for utf convertor #8613 replace some incorrect implementation, revert some test changes, need sync before testing * StringUtils for utf convertor #8613 fixed several thing that were badly implemented yesterday, need optimization, testing (before testing have to be add support of u32 and u16 buffer visualization) * StringUtils for utf convertor #8613 fixed to support u16 and u32, and convertor in ndarray, fix buffer print, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master and sync with server Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 some correction for string cast, need print check only asci support Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 merge master, remove copies and add cast, need test, refactoring according review and clean up * StringUtils for utf convertor #8613 fixed cast and copy issues Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda and update tests * StringUtils for utf convertor #8613 integration into NdArray, fix several tests for build pass, refactoring, etc * - avoid ambiguity of NDArray ctrs overloading in some tests Signed-off-by: Yurii <iuriish@yahoo.com> * StringUtils for utf convertor #8613 NDArray string constructors added, updated NDArrayFactory, refactoring unicode and tests, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed cuda build and test, refactoring and void* added to some functions Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 void* integration, removed copy operation, refactoring, added tests for NDArray string constructors, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 several more fixes, improvements and updates Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 master merge, code clean up and optimization before review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 minor fixes string element size define Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 revert last changes as mistake Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 fixed NDArray constructor build problem, remove order from string factory, fixed order use for factory via project, added catch of incorrect sync in cast of arrays to data types, fixed e method for strings, etc Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 added javacpp hack, added multi-threading, minor corrections in license agreement Signed-off-by: Oleg <oleg.semeniv@gmail.com> * StringUtils for utf convertor #8613 windows builds fix, as "sting" is not treated as utf8 Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: Yurii Shyrma <iuriish@yahoo.com>master
parent
00cd61f32d
commit
d52e67209e
|
@ -195,6 +195,56 @@ namespace nd4j {
|
|||
|
||||
NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
/**
|
||||
* This contructors create scalar array containing string utf8
|
||||
*
|
||||
*/
|
||||
NDArray(const char* str, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
|
||||
: NDArray(std::string(str), dtype, context) {
|
||||
}
|
||||
NDArray(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
/**
|
||||
* This contructors create scalar array containing string utf16
|
||||
*
|
||||
*/
|
||||
NDArray(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
|
||||
: NDArray(std::u16string(u16string), dtype, context) {
|
||||
}
|
||||
|
||||
NDArray(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
/**
|
||||
* This contructors create scalar array containing string utf32
|
||||
*
|
||||
*/
|
||||
NDArray(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext())
|
||||
: NDArray(std::u32string(u32string), dtype, context) {
|
||||
}
|
||||
|
||||
NDArray(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
/**
|
||||
* This contructors create array from vector of utf8 strings
|
||||
*
|
||||
*/
|
||||
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::string>& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
/**
|
||||
* This contructors create array from vector of utf16 strings
|
||||
*
|
||||
*/
|
||||
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
/**
|
||||
* This contructors create array from vector of utf32 strings
|
||||
*
|
||||
*/
|
||||
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
@ -250,7 +300,6 @@ namespace nd4j {
|
|||
*/
|
||||
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
|
||||
|
||||
|
||||
/**
|
||||
* This method returns new array with the same shape & data type
|
||||
* @return
|
||||
|
@ -1148,6 +1197,9 @@ namespace nd4j {
|
|||
template <typename N>
|
||||
NDArray asT() const;
|
||||
|
||||
template <typename S>
|
||||
NDArray asS() const;
|
||||
|
||||
NDArray asT(DataType dtype) const;
|
||||
|
||||
|
||||
|
@ -1441,7 +1493,7 @@ namespace nd4j {
|
|||
* @return
|
||||
*/
|
||||
bool isS() const;
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> asVectorT();
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -16,6 +17,7 @@
|
|||
|
||||
//
|
||||
// Created by raver119 on 2018-09-16.
|
||||
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||
//
|
||||
|
||||
#ifndef DEV_TESTS_NDARRAYFACTORY_H
|
||||
|
@ -106,25 +108,72 @@ namespace nd4j {
|
|||
template <typename T>
|
||||
static NDArray create(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<T>& data, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
||||
static NDArray string(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
/**
|
||||
* This factory create array from utf8 string
|
||||
* @return NDArray default dataType UTF8
|
||||
*/
|
||||
static NDArray string(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* string_(const char *string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* string_(const std::string &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray string(const std::string& string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
static NDArray* string_(const char *string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
/**
|
||||
* This factory create array from utf16 string
|
||||
* @return NDArray default dataType UTF16
|
||||
*/
|
||||
static NDArray string(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_(const char16_t* u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray string(const std::u16string& u16string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
/**
|
||||
* This factory create array from utf32 string
|
||||
* @return NDArray default dataType UTF32
|
||||
*/
|
||||
static NDArray string(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_(const char32_t* u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray string(const std::u32string& u32string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
static NDArray string(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
/**
|
||||
* This factory create array from vector of utf8 strings
|
||||
* @return NDArray default dataType UTF8
|
||||
*/
|
||||
static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dtype = nd4j::DataType::UTF8, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
||||
static NDArray* string_(const std::string &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
/**
|
||||
* This factory create array from vector of utf16 strings
|
||||
* @return NDArray default dataType UTF16
|
||||
*/
|
||||
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype = nd4j::DataType::UTF16, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
/**
|
||||
* This factory create array from vector of utf32 strings
|
||||
* @return NDArray default dataType UTF32
|
||||
*/
|
||||
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray string( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
static NDArray* string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype = nd4j::DataType::UTF32, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray string(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
||||
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
||||
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
static NDArray* string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
||||
static ResultSet createSetOfArrs(const Nd4jLong numOfArrs, const void* buffer, const Nd4jLong* shapeInfo, const Nd4jLong* offsets, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -16,6 +17,7 @@
|
|||
|
||||
//
|
||||
// Created by GS <sgazeos@gmail.com> on 2018-12-20.
|
||||
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||
//
|
||||
|
||||
#include <NDArrayFactory.h>
|
||||
|
@ -25,6 +27,9 @@
|
|||
#include <ShapeUtils.h>
|
||||
#include <type_traits>
|
||||
|
||||
|
||||
#include <StringUtils.h>
|
||||
|
||||
namespace nd4j {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -85,45 +90,6 @@ namespace nd4j {
|
|||
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<uint8_t>& data, nd4j::LaunchContext * context);
|
||||
template ND4J_EXPORT NDArray NDArrayFactory::create(const char order, const std::vector<Nd4jLong> &shape, const std::vector<bool>& data, nd4j::LaunchContext * context);
|
||||
|
||||
NDArray NDArrayFactory::string(const char *str, nd4j::LaunchContext * context) {
|
||||
std::string s(str);
|
||||
return string(s, context);
|
||||
}
|
||||
|
||||
NDArray* NDArrayFactory::string_(const char *str, nd4j::LaunchContext * context) {
|
||||
return string_(std::string(str), context);
|
||||
}
|
||||
|
||||
NDArray NDArrayFactory::string(const std::string &str, nd4j::LaunchContext * context) {
|
||||
|
||||
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1);
|
||||
|
||||
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(headerLength + str.length(), DataType::UTF8, context->getWorkspace(), true);
|
||||
|
||||
NDArray res(pBuffer, ShapeDescriptor::scalarDescriptor(DataType::UTF8), context);
|
||||
|
||||
int8_t* buffer = reinterpret_cast<int8_t*>(res.getBuffer());
|
||||
|
||||
auto offsets = reinterpret_cast<Nd4jLong *>(buffer);
|
||||
offsets[0] = 0;
|
||||
offsets[1] = str.length();
|
||||
|
||||
auto data = buffer + headerLength;
|
||||
|
||||
memcpy(data, str.c_str(), str.length());
|
||||
|
||||
res.tickWriteHost();
|
||||
res.syncToDevice();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
NDArray* NDArrayFactory::string_(const std::string &str, nd4j::LaunchContext * context) {
|
||||
auto res = new NDArray();
|
||||
*res = NDArrayFactory::string(str, context);
|
||||
return res;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
NDArray* NDArrayFactory::create_(const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context) {
|
||||
|
@ -551,91 +517,175 @@ template ND4J_EXPORT NDArray NDArrayFactory::create(uint8_t * buffer, const char
|
|||
template ND4J_EXPORT NDArray NDArrayFactory::create(int8_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
|
||||
template ND4J_EXPORT NDArray NDArrayFactory::create(int16_t* buffer, const char order, const std::initializer_list<Nd4jLong>& shape, nd4j::LaunchContext * context);
|
||||
|
||||
|
||||
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context) {
|
||||
std::vector<const char*> vec(strings);
|
||||
return NDArrayFactory::string(order, shape, vec, context);
|
||||
}
|
||||
|
||||
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context) {
|
||||
std::vector<std::string> vec(strings.size());
|
||||
int cnt = 0;
|
||||
for (auto s:strings)
|
||||
vec[cnt++] = std::string(s);
|
||||
|
||||
return NDArrayFactory::string(order, shape, vec, context);
|
||||
}
|
||||
|
||||
|
||||
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context) {
|
||||
std::vector<std::string> vec(string);
|
||||
return NDArrayFactory::string(order, shape, vec, context);
|
||||
}
|
||||
|
||||
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::LaunchContext * context) {
|
||||
std::vector<const char*> vec(strings);
|
||||
return NDArrayFactory::string_(order, shape, vec, context);
|
||||
}
|
||||
|
||||
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::LaunchContext * context) {
|
||||
std::vector<std::string> vec(strings.size());
|
||||
int cnt = 0;
|
||||
for (auto s:strings)
|
||||
vec[cnt++] = std::string(s);
|
||||
|
||||
return NDArrayFactory::string_(order, shape, vec, context);
|
||||
}
|
||||
|
||||
|
||||
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::LaunchContext * context) {
|
||||
std::vector<std::string> vec(string);
|
||||
return NDArrayFactory::string_(order, shape, vec, context);
|
||||
}
|
||||
|
||||
NDArray NDArrayFactory::string(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context) {
|
||||
|
||||
if (context == nullptr)
|
||||
context = nd4j::LaunchContext ::defaultContext();
|
||||
|
||||
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size());
|
||||
|
||||
std::vector<Nd4jLong> offsets(string.size() + 1);
|
||||
Nd4jLong dataLength = 0;
|
||||
for (int e = 0; e < string.size(); e++) {
|
||||
offsets[e] = dataLength;
|
||||
dataLength += string[e].length();
|
||||
}
|
||||
offsets[string.size()] = dataLength;
|
||||
|
||||
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(headerLength + dataLength, DataType::UTF8, context->getWorkspace(), true);
|
||||
|
||||
NDArray res(pBuffer, ShapeDescriptor(DataType::UTF8, order, shape), context);
|
||||
res.setAttached(context->getWorkspace() != nullptr);
|
||||
|
||||
if (res.lengthOf() != string.size())
|
||||
throw std::invalid_argument("Number of strings should match length of array");
|
||||
|
||||
memcpy(res.buffer(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
||||
auto data = static_cast<int8_t*>(res.buffer()) + headerLength;
|
||||
int resLen = res.lengthOf();
|
||||
for (int e = 0; e < resLen; e++) {
|
||||
auto length = offsets[e+1] - offsets[e];
|
||||
auto cdata = data + offsets[e];
|
||||
memcpy(cdata, string[e].c_str(), string[e].length());
|
||||
}
|
||||
|
||||
res.tickWriteHost();
|
||||
res.syncToDevice();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
NDArray* NDArrayFactory::string_(char order, const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::LaunchContext * context) {
|
||||
auto res = new NDArray();
|
||||
*res = NDArrayFactory::string(order, shape, string, context);
|
||||
return res;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return NDArray(u16string, dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_(const char16_t* u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return string_(std::u16string(u16string), dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
auto res = new NDArray();
|
||||
*res = NDArray(u16string, dtype, context);
|
||||
return res;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string(const std::u16string& u16string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return NDArray(u16string, dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return NDArray(u32string, dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_(const char32_t* u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return string_(std::u32string(u32string), dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
auto res = new NDArray();
|
||||
*res = NDArray(u32string, dtype, context);
|
||||
return res;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string(const std::u32string& u32string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return NDArray(u32string, dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return NDArray(str, dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_(const char* str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return string_(std::string(str), dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
auto res = new NDArray();
|
||||
*res = NDArray(str, dtype, context);
|
||||
return res;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string(const std::string& str, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return NDArray(str, dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string(const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||
return NDArray(shape, std::vector<const char*>(strings), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||
return NDArray( shape, strings, dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||
return NDArray( shape, std::vector<std::string>(string), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||
return NDArrayFactory::string_( shape, std::vector<const char*>(strings), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::vector<const char *> &strings, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||
std::vector<std::string> vec(strings.size());
|
||||
int cnt = 0;
|
||||
for (auto s:strings)
|
||||
vec[cnt++] = std::string(s);
|
||||
|
||||
return NDArrayFactory::string_( shape, vec, dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong> &shape, const std::initializer_list<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||
return NDArrayFactory::string_( shape, std::vector<std::string>(string), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string( const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||
return NDArray(shape, string, dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_(const std::vector<Nd4jLong> &shape, const std::vector<std::string> &string, nd4j::DataType dataType, nd4j::LaunchContext * context) {
|
||||
auto res = new NDArray();
|
||||
*res = NDArray( shape, string, dataType, context);
|
||||
return res;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string(const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArray( shape, std::vector<const char16_t*>(strings), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArray( shape, strings, dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArray( shape, std::vector<std::u16string>(string), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArrayFactory::string_( shape, std::vector<const char16_t*>(strings), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<const char16_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
std::vector<std::u16string> vec(strings.size());
|
||||
int cnt = 0;
|
||||
for (auto s : strings)
|
||||
vec[cnt++] = std::u16string(s);
|
||||
|
||||
return NDArrayFactory::string_( shape, vec, dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArrayFactory::string_( shape, std::vector<std::u16string>(string), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
auto res = new NDArray();
|
||||
*res = NDArray( shape, string, dataType, context);
|
||||
return res;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<std::u16string>& string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return NDArray( shape, string, dtype, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArray( shape, std::vector<const char32_t*>(strings), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArray( shape, strings, dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArray(shape, std::vector<std::u32string>(string), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArrayFactory::string_( shape, std::vector<const char32_t*>(strings), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<const char32_t*>& strings, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
std::vector<std::u32string> vec(strings.size());
|
||||
int cnt = 0;
|
||||
for (auto s : strings)
|
||||
vec[cnt++] = std::u32string(s);
|
||||
return NDArrayFactory::string_( shape, vec, dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::initializer_list<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
return NDArrayFactory::string_( shape, std::vector<std::u32string>(string), dataType, context);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray* NDArrayFactory::string_( const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
auto res = new NDArray();
|
||||
*res = NDArray( shape, string, dataType, context);
|
||||
return res;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArrayFactory::string(const std::vector<Nd4jLong>& shape, const std::vector<std::u32string>& string, nd4j::DataType dtype, nd4j::LaunchContext* context) {
|
||||
return NDArray( shape, string, dtype, context);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -122,7 +122,7 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
FORCEINLINE bool DataTypeUtils::isS(nd4j::DataType dataType) {
|
||||
return dataType == nd4j::DataType::UTF8;
|
||||
return dataType == nd4j::DataType::UTF8 || dataType == nd4j::DataType::UTF16 || dataType == nd4j::DataType::UTF32;
|
||||
}
|
||||
|
||||
FORCEINLINE bool DataTypeUtils::isZ(nd4j::DataType dataType) {
|
||||
|
@ -370,6 +370,10 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
|||
return std::string("UINT64");
|
||||
case UTF8:
|
||||
return std::string("UTF8");
|
||||
case UTF16:
|
||||
return std::string("UTF16");
|
||||
case UTF32:
|
||||
return std::string("UTF32");
|
||||
default:
|
||||
throw std::runtime_error("Unknown data type used");
|
||||
}
|
||||
|
@ -431,6 +435,8 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
|
|||
case nd4j::DataType::UINT16: return (size_t) 2;
|
||||
|
||||
case nd4j::DataType::UTF8:
|
||||
case nd4j::DataType::UTF16:
|
||||
case nd4j::DataType::UTF32:
|
||||
case nd4j::DataType::INT32:
|
||||
case nd4j::DataType::UINT32:
|
||||
case nd4j::DataType::HALF2:
|
||||
|
@ -455,6 +461,10 @@ FORCEINLINE _CUDA_HD T DataTypeUtils::eps() {
|
|||
return nd4j::DataType::BOOL;
|
||||
} else if (std::is_same<T, std::string>::value) {
|
||||
return nd4j::DataType::UTF8;
|
||||
} else if (std::is_same<T, std::u16string>::value) {
|
||||
return nd4j::DataType::UTF16;
|
||||
} else if (std::is_same<T, std::u32string>::value) {
|
||||
return nd4j::DataType::UTF32;
|
||||
} else if (std::is_same<T, float>::value) {
|
||||
return nd4j::DataType::FLOAT32;
|
||||
} else if (std::is_same<T, float16>::value) {
|
||||
|
|
|
@ -49,12 +49,11 @@ namespace nd4j {
|
|||
delete[] newShape;
|
||||
return NDArrayFactory::empty_(dtype, nullptr);
|
||||
}
|
||||
|
||||
// TODO fix UTF16 and UTF32
|
||||
if (dtype == UTF8) {
|
||||
bool isBe = BitwiseUtils::isBE();
|
||||
bool canKeep = (isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_BE) || (!isBe && flatArray->byteOrder() == nd4j::graph::ByteOrder_LE);
|
||||
auto order = shape::order(newShape);
|
||||
|
||||
|
||||
std::vector<std::string> substrings(length);
|
||||
std::vector<Nd4jLong> shapeVector(rank);
|
||||
for (int e = 0; e < rank; e++)
|
||||
|
@ -88,8 +87,8 @@ namespace nd4j {
|
|||
|
||||
delete[] offsets;
|
||||
delete[] newShape;
|
||||
|
||||
return NDArrayFactory::string_(order, shapeVector, substrings);
|
||||
// string order always 'c'
|
||||
return NDArrayFactory::string_(shapeVector, substrings);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -171,7 +171,10 @@ namespace nd4j {
|
|||
* @param numStrings
|
||||
* @return
|
||||
*/
|
||||
static Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings);
|
||||
static FORCEINLINE Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings) {
|
||||
// we store +1 offset
|
||||
return (numStrings + 1) * sizeof(Nd4jLong);
|
||||
}
|
||||
|
||||
/*
|
||||
* check whether arr1/arr2 is sub-array of arr2/arr1,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -16,6 +17,7 @@
|
|||
|
||||
//
|
||||
// Created by raver119 on 20/04/18.
|
||||
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_STRINGUTILS_H
|
||||
|
@ -27,6 +29,7 @@
|
|||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <NDArray.h>
|
||||
#include <unicode.h>
|
||||
|
||||
namespace nd4j {
|
||||
class ND4J_EXPORT StringUtils {
|
||||
|
@ -85,6 +88,55 @@ namespace nd4j {
|
|||
* @return
|
||||
*/
|
||||
static std::vector<std::string> split(const std::string &haystack, const std::string &delimiter);
|
||||
|
||||
|
||||
/**
|
||||
* This method convert u8 string to u16
|
||||
* @param const reference to input string
|
||||
* @param reference to output u16string
|
||||
* @return boolean status
|
||||
*/
|
||||
static bool u8StringToU16String(const std::string& u8, std::u16string& u16);
|
||||
|
||||
/**
|
||||
* This method convert u8 string to u32
|
||||
* @param const reference to input string
|
||||
* @param reference to output u32string
|
||||
* @return boolean status
|
||||
*/
|
||||
static bool u8StringToU32String(const std::string& u8, std::u32string& u32);
|
||||
|
||||
/**
|
||||
* This method convert u16 string to u32
|
||||
* @param const reference to input u16string
|
||||
* @param reference to output u32string
|
||||
* @return boolean status
|
||||
*/
|
||||
static bool u16StringToU32String(const std::u16string& u16, std::u32string& u32);
|
||||
|
||||
/**
|
||||
* This method convert u16 string to u8 string
|
||||
* @param const reference to input u16string
|
||||
* @param reference to output string
|
||||
* @return boolean status
|
||||
*/
|
||||
static bool u16StringToU8String(const std::u16string& u16, std::string& u8);
|
||||
|
||||
/**
|
||||
* This method convert u32 string to u16 string
|
||||
* @param const reference to input u32string
|
||||
* @param reference to output u16string
|
||||
* @return boolean status
|
||||
*/
|
||||
static bool u32StringToU16String(const std::u32string& u32, std::u16string& u16);
|
||||
|
||||
/**
|
||||
* This method convert u32 string to u8 string
|
||||
* @param const reference to input u32string
|
||||
* @param reference to output string
|
||||
* @return boolean status
|
||||
*/
|
||||
static bool u32StringToU8String(const std::u32string& u32, std::string& u8);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -1019,15 +1019,6 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
|
|||
return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
|
||||
}
|
||||
|
||||
|
||||
Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {
|
||||
// we store +1 offset
|
||||
auto base = numStrings + 1;
|
||||
|
||||
// since we return number of bytes...
|
||||
return base * sizeof(Nd4jLong);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/*
|
||||
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -16,6 +17,7 @@
|
|||
|
||||
//
|
||||
// Created by raver119 on 20/04/18.
|
||||
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||
//
|
||||
|
||||
#include <helpers/StringUtils.h>
|
||||
|
@ -49,13 +51,8 @@ namespace nd4j {
|
|||
if (!array.isS())
|
||||
throw nd4j::datatype_exception::build("StringUtils::byteLength expects one of String types;", array.dataType());
|
||||
|
||||
uint64_t result = 0;
|
||||
|
||||
// our buffer stores offsets, and the last value is basically number of bytes used
|
||||
auto buffer = array.bufferAsT<Nd4jLong>();
|
||||
result = buffer[array.lengthOf()];
|
||||
|
||||
return result;
|
||||
return buffer[array.lengthOf()];
|
||||
}
|
||||
|
||||
std::vector<std::string> StringUtils::split(const std::string &haystack, const std::string &delimiter) {
|
||||
|
@ -73,4 +70,89 @@ namespace nd4j {
|
|||
|
||||
return output;
|
||||
}
|
||||
|
||||
bool StringUtils::u8StringToU16String(const std::string& u8, std::u16string& u16) {
|
||||
|
||||
if (u8.empty())
|
||||
return false;
|
||||
|
||||
u16.resize(unicode::offsetUtf8StringInUtf16(u8.data(), u8.size()) / sizeof(char16_t));
|
||||
if (u8.size() == u16.size())
|
||||
u16.assign(u8.begin(), u8.end());
|
||||
else
|
||||
return unicode::utf8to16(u8.data(), &u16[0], u8.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StringUtils::u8StringToU32String(const std::string& u8, std::u32string& u32) {
|
||||
|
||||
if (u8.empty())
|
||||
return false;
|
||||
|
||||
u32.resize( unicode::offsetUtf8StringInUtf32(u8.data(), u8.size()) / sizeof(char32_t) );
|
||||
if (u8.size() == u32.size())
|
||||
u32.assign(u8.begin(), u8.end());
|
||||
else
|
||||
return unicode::utf8to32(u8.data(), &u32[0], u8.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StringUtils::u16StringToU32String(const std::u16string& u16, std::u32string& u32) {
|
||||
|
||||
if (u16.empty())
|
||||
return false;
|
||||
|
||||
u32.resize(unicode::offsetUtf16StringInUtf32(u16.data(), u16.size()) / sizeof(char32_t));
|
||||
if (u16.size() == u32.size())
|
||||
u32.assign(u16.begin(), u16.end());
|
||||
else
|
||||
return unicode::utf16to32(u16.data(), &u32[0], u16.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StringUtils::u16StringToU8String(const std::u16string& u16, std::string& u8) {
|
||||
|
||||
if (u16.empty())
|
||||
return false;
|
||||
|
||||
u8.resize(unicode::offsetUtf16StringInUtf8(u16.data(), u16.size()));
|
||||
if (u16.size() == u8.size())
|
||||
u8.assign(u16.begin(), u16.end());
|
||||
else
|
||||
return unicode::utf16to8(u16.data(), &u8[0], u16.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StringUtils::u32StringToU16String(const std::u32string& u32, std::u16string& u16) {
|
||||
|
||||
if (u32.empty())
|
||||
return false;
|
||||
|
||||
u16.resize(unicode::offsetUtf32StringInUtf16(u32.data(), u32.size()) / sizeof(char16_t));
|
||||
if (u32.size() == u16.size())
|
||||
u16.assign(u32.begin(), u32.end());
|
||||
else
|
||||
return unicode::utf32to16(u32.data(), &u16[0], u32.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StringUtils::u32StringToU8String(const std::u32string& u32, std::string& u8) {
|
||||
|
||||
if (u32.empty())
|
||||
return false;
|
||||
|
||||
u8.resize(unicode::offsetUtf32StringInUtf8(u32.data(), u32.size()));
|
||||
if (u32.size() == u8.size())
|
||||
u8.assign(u32.begin(), u32.end());
|
||||
else
|
||||
return unicode::utf32to8(u32.data(), &u8[0], u32.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,456 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2020 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||
//
|
||||
|
||||
#include <unicode.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace unicode {
|
||||
|
||||
constexpr uint32_t ONEBYTEBOUND = 0x00000080;
|
||||
constexpr uint32_t TWOBYTEBOUND = 0x00000800;
|
||||
constexpr uint32_t THREEBYTEBOUND = 0x00010000;
|
||||
constexpr uint16_t HIGHBYTEMIN = 0xd800u;
|
||||
constexpr uint16_t HIGHBYTEMAX = 0xdbffu;
|
||||
constexpr uint16_t TRAILBYTEMIN = 0xdc00u;
|
||||
constexpr uint16_t TRAILBYTEMAX = 0xdfffu;
|
||||
constexpr uint16_t HIGHBYTEOFFSET = HIGHBYTEMIN - (0x10000 >> 10);
|
||||
constexpr uint32_t BYTEOFFSET = 0x10000u - (HIGHBYTEMIN << 10) - TRAILBYTEMIN;
|
||||
// Maximum valid value for a Unicode code point
|
||||
constexpr uint32_t CODEPOINTMAX = 0x0010ffffu;
|
||||
|
||||
template<typename T>
|
||||
FORCEINLINE uint8_t castToU8(const T cp) {
|
||||
return static_cast<uint8_t>(0xff & cp);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
FORCEINLINE uint16_t castToU16(const T cp) {
|
||||
return static_cast<uint16_t>(0xffff & cp);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
FORCEINLINE uint32_t castToU32(const T cp) {
|
||||
return static_cast<uint32_t>(0xffffff & cp);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
FORCEINLINE bool isTrail(const T cp) {
|
||||
return ((castToU8(cp) >> 6) == 0x2);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE bool isHighSurrogate(const T cp) {
|
||||
return (cp & 0xfffffc00) == 0xd800;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool isLowSurrogate(const T cp) {
|
||||
return (cp & 0xfffffc00) == 0xdc00;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE bool isLeadSurrogate(const T cp) {
|
||||
return (cp >= HIGHBYTEMIN && cp <= HIGHBYTEMAX);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE bool isTrailSurrogate(const T cp) {
|
||||
return (cp >= TRAILBYTEMIN && cp <= TRAILBYTEMAX);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE bool isSurrogateU8(const T cp) {
|
||||
return (cp >= HIGHBYTEMIN && cp <= TRAILBYTEMAX);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE bool isSurrogateU16(const T cp) {
|
||||
return ((cp - 0xd800u) < 2048u);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE bool isSymbolU8Valid(const T cp) {
|
||||
return (cp <= CODEPOINTMAX && !isSurrogateU8(cp));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE bool isSymbolValid(const T cp) {
|
||||
return (cp <= CODEPOINTMAX);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE uint32_t surrogateU32(const T& high, const T& low) {
|
||||
return (high << 10) + low - 0x35fdc00;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Nd4jLong symbolLength(const T* it) {
|
||||
uint8_t lead = castToU8(*it);
|
||||
if (lead < 0x80)
|
||||
return 1;
|
||||
else if ((lead >> 5) == 0x6)
|
||||
return 2;
|
||||
else if ((lead >> 4) == 0xe)
|
||||
return 3;
|
||||
else if ((lead >> 3) == 0x1e)
|
||||
return 4;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Nd4jLong symbolLength32(const T* it) {
|
||||
auto lead = castToU32(*it);
|
||||
if (lead < ONEBYTEBOUND)
|
||||
return 1;
|
||||
else if (lead < TWOBYTEBOUND)
|
||||
return 2;
|
||||
else if (lead < THREEBYTEBOUND)
|
||||
return 3;
|
||||
else if (lead <= CODEPOINTMAX)
|
||||
return 4;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Nd4jLong symbolLength16(const T* it) {
|
||||
|
||||
uint32_t lead = castToU16(*it);
|
||||
if (!isLeadSurrogate(lead)) {
|
||||
if (lead < ONEBYTEBOUND)
|
||||
return 1;
|
||||
else if (lead < TWOBYTEBOUND)
|
||||
return 2;
|
||||
else if (lead < THREEBYTEBOUND)
|
||||
return 3;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
else {
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf8StringInUtf32(const void* start, const void* end) {
|
||||
|
||||
Nd4jLong count = 0;
|
||||
for (auto it = static_cast<const int8_t*>(start); it != end; it++) {
|
||||
auto length = symbolLength(it);
|
||||
it += (length > 0) ? (length - 1) : 0;
|
||||
count += 1;
|
||||
}
|
||||
return static_cast<Nd4jLong>(count * sizeof(char32_t));
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf16StringInUtf32(const void* start, const void* end) {
|
||||
|
||||
Nd4jLong count = 0;
|
||||
for (auto it = static_cast<const uint16_t*>(start); it != end;) {
|
||||
auto length = symbolLength16(it);
|
||||
it += (4 == length) ? 2 : 1;
|
||||
count += 1;
|
||||
}
|
||||
return static_cast<Nd4jLong>(count*sizeof(char32_t));
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end) {
|
||||
|
||||
Nd4jLong count = 0;
|
||||
for (auto it = static_cast<const int8_t*>(start); it != end; it++) {
|
||||
auto length = symbolLength(it);
|
||||
auto step = ((length > 0) ? (length - 1) : 0);
|
||||
it += step;
|
||||
count += (4 == length) ? 2 : 1;
|
||||
}
|
||||
return static_cast<Nd4jLong>(count*sizeof(char16_t));
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end) {
|
||||
|
||||
Nd4jLong count = 0;
|
||||
for (auto it = static_cast<const uint16_t*>(start); it != end;) {
|
||||
auto length = symbolLength16(it);
|
||||
it += (4 == length) ? 2 : 1;
|
||||
count += length;
|
||||
}
|
||||
return static_cast<Nd4jLong>(count);
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end) {
|
||||
|
||||
Nd4jLong count = 0;
|
||||
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
|
||||
auto length = symbolLength32(it);
|
||||
count += (4 == length) ? 2 : 1;;
|
||||
}
|
||||
return static_cast<Nd4jLong>(count*sizeof(char16_t));
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end) {
|
||||
|
||||
Nd4jLong count = 0;
|
||||
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
|
||||
count += symbolLength32(it);
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
bool isStringValidU8(const void* start, const void* stop) {
|
||||
for (auto it = static_cast<const int8_t*>(start); it != stop; it++) {
|
||||
if (!isSymbolU8Valid( castToU8(*it) )) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isStringValidU16(const void* start, const void* stop) {
|
||||
for (auto it = static_cast<const uint16_t*>(start); it != stop; it++) {
|
||||
if (!isSymbolValid( castToU32(*it) )) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isStringValidU32(const void* start, const void* stop) {
|
||||
for (auto it = static_cast<const uint32_t*>(start); it != stop; it++) {
|
||||
if (!isSymbolValid( castToU32(*it) )) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void* utf16to8Ptr(const void* start, const void* end, void* res) {
|
||||
|
||||
auto result = static_cast<int8_t*>(res);
|
||||
// result have to be pre-allocated
|
||||
for (auto it = static_cast<const uint16_t*>(start); it != end;) {
|
||||
uint32_t cp = castToU16(*it++);
|
||||
if (!isLeadSurrogate(cp)) {
|
||||
if (cp < 0x80) { // for one byte
|
||||
*(result++) = static_cast<uint8_t>(cp);
|
||||
}
|
||||
else if (cp < 0x800) { // for two bytes
|
||||
*(result++) = static_cast<uint8_t>((cp >> 6) | 0xc0);
|
||||
*(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
|
||||
}
|
||||
else{ // for three bytes
|
||||
*(result++) = static_cast<uint8_t>((cp >> 12) | 0xe0);
|
||||
*(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
|
||||
*(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (it != end) {
|
||||
uint32_t trail_surrogate = castToU16(*it++);
|
||||
if (isTrailSurrogate(trail_surrogate))
|
||||
cp = (cp << 10) + trail_surrogate + BYTEOFFSET;
|
||||
}
|
||||
// for four bytes
|
||||
*(result++) = static_cast<uint8_t>((cp >> 18) | 0xf0);
|
||||
*(result++) = static_cast<uint8_t>(((cp >> 12) & 0x3f) | 0x80);
|
||||
*(result++) = static_cast<uint8_t>(((cp >> 6) & 0x3f) | 0x80);
|
||||
*(result++) = static_cast<uint8_t>((cp & 0x3f) | 0x80);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void* utf8to16Ptr(const void* start, const void* end, void* res) {
|
||||
|
||||
auto result = static_cast<uint16_t*>(res);
|
||||
// result have to be pre-allocated
|
||||
for (auto it = static_cast<const int8_t*>(start); it != end;) {
|
||||
|
||||
auto nLength = symbolLength(it);
|
||||
uint32_t cp = castToU8(*it++);
|
||||
if (4 != nLength) {
|
||||
if (2 == nLength) {
|
||||
cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f);
|
||||
}
|
||||
else if (3 == nLength) {
|
||||
cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff);
|
||||
cp += (*it++) & 0x3f;
|
||||
}
|
||||
*(result++) = static_cast<uint16_t>(cp);
|
||||
}
|
||||
else {
|
||||
cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff);
|
||||
cp += (castToU8(*it++) << 6) & 0xfff;
|
||||
cp += (*it++) & 0x3f;
|
||||
//make a surrogate pair
|
||||
*(result++) = static_cast<uint16_t>((cp >> 10) + HIGHBYTEOFFSET);
|
||||
*(result++) = static_cast<uint16_t>((cp & 0x3ff) + TRAILBYTEMIN);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void* utf32to8Ptr( const void* start, const void* end, void* result) {
|
||||
|
||||
auto res = static_cast<uint8_t*>(result);
|
||||
// result have to be pre-allocated
|
||||
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
|
||||
|
||||
if (*it < 0x80) // for one byte
|
||||
*(res++) = static_cast<uint8_t>(*it);
|
||||
else if (*it < 0x800) { // for two bytes
|
||||
*(res++) = static_cast<uint8_t>((*it >> 6) | 0xc0);
|
||||
*(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
|
||||
}
|
||||
else if (*it < 0x10000) { // for three bytes
|
||||
*(res++) = static_cast<uint8_t>((*it >> 12) | 0xe0);
|
||||
*(res++) = static_cast<uint8_t>(((*it >> 6) & 0x3f) | 0x80);
|
||||
*(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
|
||||
}
|
||||
else { // for four bytes
|
||||
*(res++) = static_cast<uint8_t>((*it >> 18) | 0xf0);
|
||||
*(res++) = static_cast<uint8_t>(((*it >> 12) & 0x3f) | 0x80);
|
||||
*(res++) = static_cast<uint8_t>(((*it >> 6) & 0x3f) | 0x80);
|
||||
*(res++) = static_cast<uint8_t>((*it & 0x3f) | 0x80);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void* utf8to32Ptr(const void* start, const void* end, void* res) {
|
||||
|
||||
auto result = static_cast<uint32_t*>(res);
|
||||
// result have to be pre-allocated
|
||||
for (auto it = static_cast<const int8_t*>(start); it != end;) {
|
||||
|
||||
auto nLength = symbolLength(it);
|
||||
uint32_t cp = castToU8(*it++);
|
||||
if (2 == nLength) {
|
||||
cp = ((cp << 6) & 0x7ff) + ((*it++) & 0x3f);
|
||||
}
|
||||
else if (3 == nLength) {
|
||||
cp = ((cp << 12) & 0xffff) + ((castToU8(*it++) << 6) & 0xfff);
|
||||
cp += (*it++) & 0x3f;
|
||||
}
|
||||
else if (4 == nLength) {
|
||||
cp = ((cp << 18) & 0x1fffff) + ((castToU8(*it++) << 12) & 0x3ffff);
|
||||
cp += (castToU8(*it++) << 6) & 0xfff;
|
||||
cp += (*it++) & 0x3f;
|
||||
}
|
||||
(*result++) = cp;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void* utf16to32Ptr(const void* start, const void* end, void* res) {
|
||||
|
||||
auto result = static_cast<uint32_t*>(res);
|
||||
// result have to be pre-allocated
|
||||
for (auto it = static_cast<const uint16_t*>(start); it != end; it++) {
|
||||
|
||||
uint32_t cpHigh = castToU32(*it);
|
||||
if (!isSurrogateU16(cpHigh)) {
|
||||
*result++ = cpHigh;
|
||||
}
|
||||
else {
|
||||
it++;
|
||||
uint32_t cpLow = castToU32(*it);
|
||||
if (isHighSurrogate(cpHigh) && it != end && isLowSurrogate(cpLow)) {
|
||||
*result++ = surrogateU32(cpHigh, cpLow);
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void* utf32to16Ptr(const void* start, const void* end, void* res) {
|
||||
|
||||
auto result = static_cast<uint16_t*>(res);
|
||||
// result have to be pre-allocate
|
||||
for (auto it = static_cast<const uint32_t*>(start); it != end; it++) {
|
||||
|
||||
uint32_t cpHigh = castToU32(*it);
|
||||
// todo check do we need this as we have pre-validation, if yes find out how to check u16
|
||||
if (cpHigh < 0 || cpHigh > 0x10FFFF || (cpHigh >= 0xD800 && cpHigh <= 0xDFFF)) {
|
||||
// Invalid code point. Replace with sentinel, per Unicode standard:
|
||||
*result++ = u'\uFFFD';
|
||||
}
|
||||
else if (cpHigh < 0x10000UL) { // In the BMP.
|
||||
*result++ = static_cast<char16_t>(cpHigh);
|
||||
}
|
||||
else {
|
||||
*result++ = static_cast<char16_t>(((cpHigh - 0x10000UL) / 0x400U) + 0xD800U);
|
||||
*result++ = static_cast<char16_t>(((cpHigh - 0x10000UL) % 0x400U) + 0xDC00U);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize) {
|
||||
return offsetUtf8StringInUtf32(input, static_cast<const int8_t*>(input) + nInputSize);
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize) {
|
||||
return offsetUtf16StringInUtf32(input, static_cast<const uint16_t*>(input) + nInputSize);
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize) {
|
||||
return offsetUtf8StringInUtf16(input, static_cast<const int8_t*>(input) + nInputSize);
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize) {
|
||||
return offsetUtf16StringInUtf8(input, static_cast<const uint16_t*>(input) + nInputSize);
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize) {
|
||||
return offsetUtf32StringInUtf8(input, static_cast<const uint32_t*>(input) + nInputSize);
|
||||
}
|
||||
|
||||
Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize) {
|
||||
return offsetUtf32StringInUtf16(input, static_cast<const uint32_t*>(input) + nInputSize);
|
||||
}
|
||||
|
||||
bool utf8to16(const void* input, void* output, uint32_t nInputSize) {
|
||||
return utf8to16Ptr(input, static_cast<const int8_t*>(input) + nInputSize, output);
|
||||
}
|
||||
|
||||
bool utf8to32(const void* input, void* output, uint32_t nInputSize) {
|
||||
return utf8to32Ptr(input, static_cast<const int8_t*>(input) + nInputSize, output);
|
||||
}
|
||||
|
||||
bool utf16to32(const void* input, void* output, uint32_t nInputSize) {
|
||||
return utf16to32Ptr(input, static_cast<const uint16_t*>(input) + nInputSize, output);
|
||||
}
|
||||
|
||||
bool utf16to8(const void* input, void* output, uint32_t nInputSize) {
|
||||
return utf16to8Ptr(input, static_cast<const uint16_t*>(input) + nInputSize, output);
|
||||
}
|
||||
|
||||
bool utf32to16(const void* input, void* output, uint32_t nInputSize) {
|
||||
return utf32to16Ptr(input, static_cast<const uint32_t*>(input) + nInputSize, output);
|
||||
}
|
||||
|
||||
bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize) {
|
||||
return utf32to8Ptr(input, static_cast<const uint32_t*>(input) + nInputSize, output);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,189 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_UNICODE_H
|
||||
#define LIBND4J_UNICODE_H
|
||||
|
||||
#include <NDArray.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace unicode {
|
||||
|
||||
/**
|
||||
* This method calculate u16 offset based on utf8
|
||||
* @param const pointer to the utf8 string start point
|
||||
* @param size of the string
|
||||
* @return offset of utf16
|
||||
*/
|
||||
Nd4jLong offsetUtf8StringInUtf16(const void* start, const void* end);
|
||||
|
||||
/**
|
||||
* This method calculate u8 offset based on utf16
|
||||
* @param const pointer to the utf16 string start point
|
||||
* @param size of the string
|
||||
* @return offset of utf8
|
||||
*/
|
||||
Nd4jLong offsetUtf16StringInUtf8(const void* start, const void* end);
|
||||
|
||||
/**
|
||||
* This method calculate u32 offset based on utf16
|
||||
* @param const pointer to the utf16 string start point
|
||||
* @param size of the string
|
||||
* @return offset of utf32
|
||||
*/
|
||||
Nd4jLong offsetUtf32StringInUtf16(const void* start, const void* end);
|
||||
|
||||
/**
|
||||
* This method calculate u32 offset based on utf8
|
||||
* @param const pointer to the utf16 string start point
|
||||
* @param size of the string
|
||||
* @return offset of utf8
|
||||
*/
|
||||
Nd4jLong offsetUtf32StringInUtf8(const void* start, const void* end);
|
||||
|
||||
/*
|
||||
* This function check is valid charecter in u8 string
|
||||
*/
|
||||
bool isStringValidU8(const void* start, const void* stop);
|
||||
|
||||
/*
|
||||
* This function check is valid charecter in u16 string
|
||||
*/
|
||||
bool isStringValidU16(const void* start, const void* stop);
|
||||
|
||||
/*
|
||||
* This function check is valid u32 charecter in string
|
||||
*/
|
||||
bool isStringValidU32(const void* start, const void* stop);
|
||||
|
||||
/**
|
||||
* This method count offset for utf8 string in utf32
|
||||
* @param const pointer to the utf8 string start point
|
||||
* @param size of the string
|
||||
* @return offset
|
||||
*/
|
||||
Nd4jLong offsetUtf8StringInUtf32(const void* input, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method count offset for utf8 string in utf32
|
||||
* @param const pointer to the utf8 string start point
|
||||
* @param const end pointer to the utf8 string
|
||||
* @return offset
|
||||
*/
|
||||
Nd4jLong offsetUtf8StringInUtf32(const void* input, const void* stop);
|
||||
|
||||
/**
|
||||
* This method count offset for utf32 based on utf16 string
|
||||
* @param const pointer to the utf16 string start point
|
||||
* @param size of the string
|
||||
* @return offset
|
||||
*/
|
||||
Nd4jLong offsetUtf16StringInUtf32(const void* input, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method calculate offset of u16 based on utf8
|
||||
* @param const pointer to the utf8 string start point
|
||||
* @param size of the string
|
||||
* @return offset of utf16
|
||||
*/
|
||||
Nd4jLong offsetUtf8StringInUtf16(const void* input, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method calculate offset of u8 based on utf16
|
||||
* @param const pointer to the utf16 string start point
|
||||
* @param size of the string
|
||||
* @return offset of utf8
|
||||
*/
|
||||
Nd4jLong offsetUtf16StringInUtf8(const void* input, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method calculate offset of u32 based on utf8
|
||||
* @param const pointer to the utf16 string start point
|
||||
* @param size of the string
|
||||
* @return offset of utf32
|
||||
*/
|
||||
Nd4jLong offsetUtf32StringInUtf8(const void* input, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method calculate offset of u32 based on utf16
|
||||
* @param const pointer to the utf16 string start point
|
||||
* @param size of the string
|
||||
* @return offset of utf32
|
||||
*/
|
||||
Nd4jLong offsetUtf32StringInUtf16(const void* input, const uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method convert utf8 string to utf16 string
|
||||
* @param const pointer to the utf8 string start point
|
||||
* @param reference to start point to utf16
|
||||
* @param size of input utf8 string
|
||||
* @return status of convertion
|
||||
*/
|
||||
bool utf8to16(const void* input, void* output, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method convert utf8 string to utf32 string
|
||||
* @param const pointer to the utf8 string start point
|
||||
* @param reference to start point to utf32
|
||||
* @param size of input utf8 string
|
||||
* @return status of convertion
|
||||
*/
|
||||
bool utf8to32(const void* input, void* output, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method convert utf16 string to utf32 string
|
||||
* @param const pointer to the utf16 string start point
|
||||
* @param reference to start point to utf32
|
||||
* @param size of input utf16 string
|
||||
* @return status of convertion
|
||||
*/
|
||||
bool utf16to32(const void* input, void* output, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method convert utf16 string to utf8 string
|
||||
* @param const pointer to the utf16 string start point
|
||||
* @param reference to start point to utf8
|
||||
* @param size of input utf16 string
|
||||
* @return status of convertion
|
||||
*/
|
||||
bool utf16to8(const void* input, void* output, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method convert utf32 string to utf16 string
|
||||
* @param const pointer to the utf32 string start point
|
||||
* @param reference to start point to utf16
|
||||
* @param size of input utf32 string
|
||||
* @return status of convertion
|
||||
*/
|
||||
bool utf32to16(const void* input, void* output, uint32_t nInputSize);
|
||||
|
||||
/**
|
||||
* This method convert utf32 string to utf8 string
|
||||
* @param const pointer to the utf32 string start point
|
||||
* @param reference to start point to utf8
|
||||
* @param size of input utf32 string
|
||||
* @return status of convertion
|
||||
*/
|
||||
bool utf32to8(const void* input, void* output, const Nd4jLong nInputSize);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif //LIBND4J_UNICODE_H
|
|
@ -118,7 +118,7 @@ namespace ops {
|
|||
DECLARE_TYPES(Pow_bp) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS })
|
||||
->setAllowedOutputTypes({ ALL_FLOATS }); // TODO maybe wourth to add ALL_INTS
|
||||
->setAllowedOutputTypes({ ALL_FLOATS });
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
// now once we have all strings in single vector time to fill
|
||||
auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings);
|
||||
auto tmp = NDArrayFactory::string({(Nd4jLong) strings.size()}, strings);
|
||||
auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());
|
||||
|
||||
// for CUDA mostly
|
||||
|
|
|
@ -33,6 +33,11 @@
|
|||
#include <type_boilerplate.h>
|
||||
|
||||
|
||||
#define LIBND4J_STRINGTYPES \
|
||||
(nd4j::DataType::UTF8, std::string),\
|
||||
(nd4j::DataType::UTF16, std::u16string), \
|
||||
(nd4j::DataType::UTF32, std::u32string)
|
||||
|
||||
#define LIBND4J_TYPES \
|
||||
(nd4j::DataType::BFLOAT16, bfloat16),\
|
||||
(nd4j::DataType::HALF, float16), \
|
||||
|
|
|
@ -599,7 +599,7 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_2) {
|
|||
TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
|
||||
|
||||
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||
NDArray y('c', {}, {0.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', {}, std::vector<double>{0.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||
|
||||
nd4j::ops::maximum op;
|
||||
|
|
|
@ -626,7 +626,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test1) {
|
|||
|
||||
NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, nd4j::DataType::FLOAT32);
|
||||
NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, nd4j::DataType::FLOAT32);
|
||||
NDArray expGradB('c', {iC}, {364.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray expGradB('c', {iC}, std::vector<double>{364.5}, nd4j::DataType::FLOAT32);
|
||||
|
||||
input = 0.5;
|
||||
weights.linspace(0.1, 0.1);
|
||||
|
|
|
@ -132,11 +132,11 @@ TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) {
|
|||
NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, nd4j::DataType::BFLOAT16);
|
||||
NDArray x3('c', {2,2}, {0, -1, 0, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray scalar('c', {}, {0}, nd4j::DataType::INT64);
|
||||
NDArray scalar('c', {}, std::vector<double>{0}, nd4j::DataType::INT64);
|
||||
|
||||
NDArray exp1('c', {}, {3}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {}, {2}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {}, {1}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, std::vector<double>{3}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {}, std::vector<double>{2}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
|
||||
|
||||
void *dX1, *dX2, *dX3, *dZ;
|
||||
Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo;
|
||||
|
@ -262,11 +262,11 @@ TEST_F(CudaBasicsTests1, execReduce3Scalar_1) {
|
|||
NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray exp1('c', {}, {-30.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, {15.}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp1('c', {}, std::vector<double>{-30.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, std::vector<double>{15.}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray scalar1('c', {}, {100.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray scalar2('c', {}, {100.}, nd4j::DataType::DOUBLE);
|
||||
NDArray scalar1('c', {}, std::vector<double>{100.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray scalar2('c', {}, std::vector<double>{100.}, nd4j::DataType::DOUBLE);
|
||||
|
||||
void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2;
|
||||
Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo;
|
||||
|
@ -363,8 +363,8 @@ TEST_F(CudaBasicsTests1, execReduce3_1) {
|
|||
NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
|
||||
NDArray y('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32);
|
||||
|
||||
NDArray exp('c', {}, {-30.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {}, {100.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {}, std::vector<double>{-30.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {}, std::vector<double>{100.f}, nd4j::DataType::FLOAT32);
|
||||
|
||||
std::vector<int> dimensions = {0, 1};
|
||||
|
||||
|
@ -415,8 +415,8 @@ TEST_F(CudaBasicsTests1, execReduce3_2) {
|
|||
NDArray x('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray y('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray exp('c', {}, {15.}, nd4j::DataType::DOUBLE);
|
||||
NDArray z('c', {}, {100.}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp('c', {}, std::vector<double>{15.}, nd4j::DataType::DOUBLE);
|
||||
NDArray z('c', {}, std::vector<double>{100.}, nd4j::DataType::DOUBLE);
|
||||
|
||||
std::vector<int> dimensions = {0, 1};
|
||||
|
||||
|
@ -975,7 +975,7 @@ TEST_F(CudaBasicsTests1, execScalar_1) {
|
|||
|
||||
NDArray x('c', {2,3}, {0,1,2,3,4,5}, nd4j::DataType::INT64);
|
||||
NDArray exp('c',{2,3}, {0,0,1,1,2,2}, nd4j::DataType::INT64);
|
||||
NDArray scalar('c',{}, {2.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray scalar('c',{}, std::vector<double>{2.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::INT64);
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
|
@ -1010,7 +1010,7 @@ TEST_F(CudaBasicsTests1, execScalar_2) {
|
|||
|
||||
NDArray x('c', {2,3}, {-1,-2,-3,-4,-5,-6}, nd4j::DataType::INT64);
|
||||
NDArray exp('c',{2,3}, {10,10,10,10,10,10}, nd4j::DataType::FLOAT32);
|
||||
NDArray scalar('c',{}, {10.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray scalar('c',{}, std::vector<double>{10.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
|
@ -1103,7 +1103,7 @@ TEST_F(CudaBasicsTests1, execScalar_3) {
|
|||
TEST_F(CudaBasicsTests1, execScalarBool_1) {
|
||||
|
||||
NDArray x('c', {2,3}, {-1,-2,0,1,2,3}, nd4j::DataType::BFLOAT16);
|
||||
NDArray scalar('c',{}, {0}, nd4j::DataType::BFLOAT16);
|
||||
NDArray scalar('c',{}, std::vector<double>{0}, nd4j::DataType::BFLOAT16);
|
||||
NDArray exp('c',{2,3}, {0,0,0,1,1,1}, nd4j::DataType::BOOL);
|
||||
NDArray z('c', {2,3}, {100,100,100,100,100,100,}, nd4j::DataType::BOOL);
|
||||
|
||||
|
@ -2245,8 +2245,8 @@ TEST_F(CudaBasicsTests1, execReduceLong_2) {
|
|||
TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) {
|
||||
|
||||
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {}, {6.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {}, std::vector<double>{6.5}, nd4j::DataType::FLOAT32);
|
||||
x.permutei({2,1,0});
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
|
@ -2282,8 +2282,8 @@ TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) {
|
|||
TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) {
|
||||
|
||||
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp('c', {}, {6.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp('c', {}, std::vector<double>{6.5}, nd4j::DataType::DOUBLE);
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
cudaError_t cudaResult;
|
||||
|
@ -2318,8 +2318,8 @@ TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) {
|
|||
TEST_F(CudaBasicsTests1, execReduceSameScalar_1) {
|
||||
|
||||
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::INT32);
|
||||
NDArray exp('c', {}, {156}, nd4j::DataType::INT32);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::INT32);
|
||||
NDArray exp('c', {}, std::vector<double>{156}, nd4j::DataType::INT32);
|
||||
x.permutei({2,1,0});
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
|
@ -2355,8 +2355,8 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_1) {
|
|||
TEST_F(CudaBasicsTests1, execReduceSameScalar_2) {
|
||||
|
||||
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::DOUBLE);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp('c', {}, {156}, nd4j::DataType::DOUBLE);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp('c', {}, std::vector<double>{156}, nd4j::DataType::DOUBLE);
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
cudaError_t cudaResult;
|
||||
|
@ -2391,8 +2391,8 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_2) {
|
|||
TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) {
|
||||
|
||||
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::INT32);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::BOOL);
|
||||
NDArray exp('c', {}, {1}, nd4j::DataType::BOOL);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::BOOL);
|
||||
NDArray exp('c', {}, std::vector<double>{1}, nd4j::DataType::BOOL);
|
||||
x.permutei({2,1,0});
|
||||
x.syncShape();
|
||||
|
||||
|
@ -2429,8 +2429,8 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) {
|
|||
TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) {
|
||||
|
||||
NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::DOUBLE);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::BOOL);
|
||||
NDArray exp('c', {}, {1}, nd4j::DataType::BOOL);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::BOOL);
|
||||
NDArray exp('c', {}, std::vector<double>{1}, nd4j::DataType::BOOL);
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
cudaError_t cudaResult;
|
||||
|
@ -2465,8 +2465,8 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) {
|
|||
TEST_F(CudaBasicsTests1, execReduceLongScalar_1) {
|
||||
|
||||
NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::INT32);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::INT64);
|
||||
NDArray exp('c', {}, {17}, nd4j::DataType::INT64);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::INT64);
|
||||
NDArray exp('c', {}, std::vector<double>{17}, nd4j::DataType::INT64);
|
||||
x.permutei({2,1,0});
|
||||
x.syncShape();
|
||||
|
||||
|
@ -2503,8 +2503,8 @@ TEST_F(CudaBasicsTests1, execReduceLongScalar_1) {
|
|||
TEST_F(CudaBasicsTests1, execReduceLongScalar_2) {
|
||||
|
||||
NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::DOUBLE);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::INT64);
|
||||
NDArray exp('c', {}, {17}, nd4j::DataType::INT64);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::INT64);
|
||||
NDArray exp('c', {}, std::vector<double>{17}, nd4j::DataType::INT64);
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
cudaError_t cudaResult;
|
||||
|
@ -2685,8 +2685,8 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_4) {
|
|||
|
||||
NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
|
||||
NDArray y('c', {2,2,3}, {10,20,30,40,50,60,70,80,90,100,110,120}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp('c', {}, {1820}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {}, std::vector<double>{1820}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
|
||||
|
||||
std::vector<int> dimensions = {0,1,2};
|
||||
|
||||
|
@ -2739,8 +2739,8 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_4) {
|
|||
TEST_F(CudaBasicsTests1, execSummaryStats_1) {
|
||||
|
||||
NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64);
|
||||
NDArray exp('c', {}, {3.605551}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {}, std::vector<double>{3.605551}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
cudaError_t cudaResult;
|
||||
|
@ -2881,8 +2881,8 @@ TEST_F(CudaBasicsTests1, execSummaryStats_3) {
|
|||
TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) {
|
||||
|
||||
NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64);
|
||||
NDArray exp('c', {}, {3.605551}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {}, std::vector<double>{3.605551}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
|
||||
|
||||
// create cuda stream and LaunchContext
|
||||
cudaError_t cudaResult;
|
||||
|
|
|
@ -775,7 +775,7 @@ TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test2
|
|||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3) {
|
||||
|
||||
NDArray labels('c', {1}, {0}, nd4j::DataType::INT32);
|
||||
NDArray labels('c', {1}, std::vector<double>{0}, nd4j::DataType::INT32);
|
||||
auto logits = NDArrayFactory::create<double>('c', {1,3});
|
||||
auto expected = NDArrayFactory::create<double>('c', {1}, {1.20194});
|
||||
|
||||
|
@ -2735,7 +2735,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
|
|||
|
||||
NDArray images ('c', {1,2,2,1}, {1,2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32);
|
||||
NDArray boxI('c', {1}, {0}, nd4j::DataType::INT64);
|
||||
NDArray boxI('c', {1}, std::vector<double>{0}, nd4j::DataType::INT64);
|
||||
NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3});
|
||||
|
||||
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
||||
|
@ -2759,7 +2759,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
|
|||
|
||||
NDArray images('c', {1,2,2,1}, {1, 2, 3, 4}, nd4j::DataType::FLOAT32);
|
||||
NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32);
|
||||
NDArray boxI('c', {1}, {0}, nd4j::DataType::INT32);
|
||||
NDArray boxI('c', {1}, std::vector<double>({0.}), nd4j::DataType::INT32);
|
||||
NDArray cropSize = NDArrayFactory::create<int>({3, 3});
|
||||
|
||||
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
||||
|
@ -2933,8 +2933,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
|
|||
|
||||
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
|
||||
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
|
||||
NDArray min('c', {}, std::vector<double>{-63.65f}, nd4j::DataType::FLOAT32);
|
||||
NDArray max('c', {}, std::vector<double>{0.1f}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::fake_quant_with_min_max_vars op;
|
||||
auto results = op.evaluate({&x, &min, &max}, {}, {});
|
||||
|
|
|
@ -121,7 +121,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
|
|||
|
||||
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
|
||||
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
|
||||
NDArray dLdwExp('c', {}, {-227.77286});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{-227.77286});
|
||||
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
|
||||
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
|
||||
|
||||
|
@ -246,7 +246,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
|
|||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights(nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -350,7 +350,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test10) {
|
|||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {1,1}, {-9.49054});
|
||||
NDArray dLdwExp('c', {1,1}, std::vector<double>{-9.49054});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -1611,7 +1611,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
|
|||
|
||||
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
|
||||
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
|
||||
NDArray dLdwExp('c', {}, {4515.84});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{4515.84});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -1730,7 +1730,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
|
|||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights(nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -1830,7 +1830,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) {
|
|||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {1,1}, {188.16});
|
||||
NDArray dLdwExp('c', {1,1}, std::vector<double>{188.16});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -2056,7 +2056,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
|
|||
|
||||
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
|
||||
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
|
||||
NDArray dLdwExp('c', {}, {288.});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{288.});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -2175,7 +2175,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
|
|||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights(nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -2275,7 +2275,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) {
|
|||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {1,1}, {12.});
|
||||
NDArray dLdwExp('c', {1,1}, std::vector<double>{12.});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -2541,7 +2541,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
|
|||
|
||||
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
|
||||
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
|
||||
NDArray dLdwExp('c', {}, {-91.52109});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{-91.52109});
|
||||
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
|
||||
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
|
||||
|
||||
|
@ -2664,7 +2664,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
|
|||
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights(nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
||||
|
||||
logits.linspace(-0.08, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -2766,7 +2766,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) {
|
|||
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {1,1}, {-3.81338});
|
||||
NDArray dLdwExp('c', {1,1}, std::vector<double>{-3.81338});
|
||||
|
||||
logits.linspace(-0.08, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -2992,7 +2992,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) {
|
|||
NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
|
||||
NDArray dLdwExp('c', {1}, {1.38629});
|
||||
NDArray dLdwExp('c', {1}, std::vector<double>{1.38629});
|
||||
|
||||
logits = 2.;
|
||||
weights.assign(0.5);
|
||||
|
@ -3020,10 +3020,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
|
|||
|
||||
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
|
||||
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, std::vector<double>{0}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
|
||||
NDArray dLdwExp('c', {}, {1.38629});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{1.38629});
|
||||
|
||||
logits = 2.;
|
||||
weights.assign(0.5);
|
||||
|
@ -3051,10 +3051,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
|
|||
|
||||
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
|
||||
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, std::vector<double>{0}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{0.});
|
||||
|
||||
logits.linspace(-0.08, 0.04);
|
||||
weights = 0.5;
|
||||
|
@ -3085,7 +3085,7 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) {
|
|||
NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326});
|
||||
NDArray dLdwExp('c', {1}, {1.36729});
|
||||
NDArray dLdwExp('c', {1}, std::vector<double>{1.36729});
|
||||
|
||||
logits.linspace(-0.08, 0.04);
|
||||
weights = 0.5;
|
||||
|
@ -3321,7 +3321,7 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) {
|
|||
/////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
|
||||
|
||||
NDArray labels('c', {2,1}, {1,0});
|
||||
NDArray labels('c', {2,1}, std::vector<double>{1,0});
|
||||
NDArray logits('c', {2,1}, {-0.04, 0.04});
|
||||
|
||||
NDArray dLdpExp('c', {2,1}, {-0.51999, 0.51999});
|
||||
|
@ -3343,10 +3343,10 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) {
|
|||
/////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) {
|
||||
|
||||
NDArray labels('c', {1,2}, {1,1});
|
||||
NDArray labels('c', {1,2}, {1,1.});
|
||||
NDArray logits('c', {1,2}, {-0.04, 0.04});
|
||||
|
||||
NDArray dLdpExp('c', {1,2}, {0, 0});
|
||||
NDArray dLdpExp('c', {1,2}, {0, 0.});
|
||||
|
||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||
|
||||
|
@ -3387,10 +3387,10 @@ TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) {
|
|||
/////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) {
|
||||
|
||||
NDArray labels('c', {1}, {1});
|
||||
NDArray logits('c', {1}, {0.04});
|
||||
NDArray labels('c', {1}, std::vector<double>{1});
|
||||
NDArray logits('c', {1}, std::vector<double>{0.04});
|
||||
|
||||
NDArray dLdpExp('c', {1}, {0});
|
||||
NDArray dLdpExp('c', {1}, std::vector<double>{0});
|
||||
|
||||
nd4j::ops::softmax_cross_entropy_loss_with_logits_grad op;
|
||||
|
||||
|
@ -3483,7 +3483,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
|
|||
/////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
|
||||
|
||||
NDArray labels('c', {}, {1}, nd4j::DataType::INT64);
|
||||
NDArray labels('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
|
||||
NDArray logits('c', {2}, {-0.2, 0.3});
|
||||
|
||||
NDArray dLdpExp('c', {2}, {0.37754, -0.37754});
|
||||
|
@ -3529,7 +3529,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) {
|
|||
/////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) {
|
||||
|
||||
NDArray labels('c', {1,1}, {0}, nd4j::DataType::INT64);
|
||||
NDArray labels('c', {1,1}, std::vector<double>({0}), nd4j::DataType::INT64);
|
||||
NDArray logits('c', {1,1,2}, {-0.3,0.2});
|
||||
|
||||
NDArray dLdpExp('c', {1,1,2}, {-0.62246, 0.62246});
|
||||
|
|
|
@ -127,7 +127,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) {
|
|||
NDArray weights('c', {1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7});
|
||||
NDArray dLdwExp('c', {1}, {1.3});
|
||||
NDArray dLdwExp('c', {1}, std::vector<double>{1.3});
|
||||
NDArray dLdlExp('c', {4}, {0.2, 0.1, -0. , -0.1});
|
||||
|
||||
predictions.linspace(-0.4, 0.2);
|
||||
|
@ -158,10 +158,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) {
|
|||
|
||||
NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4});
|
||||
NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, {0.}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, std::vector<double>{0.}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7});
|
||||
NDArray dLdwExp('c', {}, {1.3});
|
||||
NDArray dLdwExp('c', {}, std::vector<double>{1.3});
|
||||
NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1});
|
||||
|
||||
predictions.linspace(-0.4, 0.2);
|
||||
|
@ -196,7 +196,7 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) {
|
|||
NDArray weights('c', {1,1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {4}, {0.1, -0.3, -2. , 1.4});
|
||||
NDArray dLdwExp('c', {1,1}, {0.});
|
||||
NDArray dLdwExp('c', {1,1}, std::vector<double>{0.});
|
||||
NDArray dLdlExp('c', {4}, {0.4, 0.2, -0. , -0.2});
|
||||
|
||||
predictions.linspace(-0.4, 0.2);
|
||||
|
@ -369,10 +369,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) {
|
|||
TEST_F(DeclarableOpsTests12, hinge_loss_14) {
|
||||
|
||||
NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, {1.});
|
||||
NDArray weights('c', {}, std::vector<double>{1.});
|
||||
NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0});
|
||||
|
||||
NDArray output('c', {}, {0.}, nd4j::DataType::DOUBLE);
|
||||
NDArray output('c', {}, std::vector<double>{0.}, nd4j::DataType::DOUBLE);
|
||||
|
||||
logits.linspace(1.);
|
||||
weights.assign(1.);
|
||||
|
@ -576,7 +576,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) {
|
|||
TEST_F(DeclarableOpsTests12, reverse_test15) {
|
||||
|
||||
NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE);
|
||||
NDArray axis('c', {}, {0}, nd4j::DataType::INT32);
|
||||
NDArray axis('c', {}, std::vector<double>{0}, nd4j::DataType::INT32);
|
||||
NDArray z('c', {5}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
|
@ -711,7 +711,7 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) {
|
|||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, tensormmul_6) {
|
||||
|
||||
NDArray x('c', {1}, {2}, nd4j::DataType::FLOAT32);
|
||||
NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
@ -1140,9 +1140,9 @@ TEST_F(DeclarableOpsTests12, lrn_bp_9) {
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, lrn_bp_10) {
|
||||
|
||||
NDArray input('c', {1,1,1,1}, {1});
|
||||
NDArray gradO('c', {1,1,1,1}, {1});
|
||||
NDArray exp('c', {1,1,1,1}, {0.19245008});
|
||||
NDArray input('c', {1,1,1,1}, std::vector<double>{1});
|
||||
NDArray gradO('c', {1,1,1,1}, std::vector<double>{1});
|
||||
NDArray exp('c', {1,1,1,1}, std::vector<double>{0.19245008});
|
||||
|
||||
nd4j::ops::lrn_bp op;
|
||||
|
||||
|
@ -1193,8 +1193,8 @@ TEST_F(DeclarableOpsTests12, lrn_2) {
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, lrn_3) {
|
||||
|
||||
NDArray input('c', {1,1,1,1}, {1.});
|
||||
NDArray exp('c', {1,1,1,1}, {0.69006556});
|
||||
NDArray input('c', {1,1,1,1}, std::vector<double>{1.});
|
||||
NDArray exp('c', {1,1,1,1}, std::vector<double>{0.69006556});
|
||||
|
||||
nd4j::ops::lrn op;
|
||||
|
||||
|
@ -1208,8 +1208,8 @@ TEST_F(DeclarableOpsTests12, lrn_3) {
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, lrn_4) {
|
||||
|
||||
NDArray input('c', {1,1,1,1}, {1.});
|
||||
NDArray exp('c', {1,1,1,1}, {0.69006556});
|
||||
NDArray input('c', {1,1,1,1}, std::vector<double>{1.});
|
||||
NDArray exp('c', {1,1,1,1}, std::vector<double>{0.69006556});
|
||||
|
||||
nd4j::ops::lrn op;
|
||||
|
||||
|
@ -1239,10 +1239,10 @@ TEST_F(DeclarableOpsTests12, lrn_5) {
|
|||
TEST_F(DeclarableOpsTests12, inTopK_1) {
|
||||
|
||||
NDArray x('c', {4, 5}, {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0});
|
||||
NDArray y('c', {4}, {0, 0, 0, 0}, nd4j::DataType::INT64);
|
||||
NDArray z('c', {4}, {1, 1, 1, 1}, nd4j::DataType::BOOL);
|
||||
NDArray y('c', {4}, {0., 0, 0, 0}, nd4j::DataType::INT64);
|
||||
NDArray z('c', {4}, {1., 1, 1, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray expV('c', {4}, {1, 0, 0, 0}, nd4j::DataType::BOOL);
|
||||
NDArray expV('c', {4}, {1., 0, 0, 0}, nd4j::DataType::BOOL);
|
||||
|
||||
nd4j::ops::in_top_k op;
|
||||
Nd4jStatus status = op.execute({&x, &y, }, {&z}, {}, {2}, {});
|
||||
|
|
|
@ -809,7 +809,7 @@ TEST_F(DeclarableOpsTests13, space_to_batch_nd_1) {
|
|||
|
||||
NDArray x('c', {1, 2, 2, 2, 3}, nd4j::DataType::FLOAT32);
|
||||
NDArray blockShape('c', {3}, {2, 2, 2} , nd4j::DataType::INT32); // three spatial dimensions
|
||||
NDArray paddings('c', {3, 2}, {0, 0, 0, 0, 0, 0} , nd4j::DataType::INT32);
|
||||
NDArray paddings('c', {3, 2}, std::vector<double>{0, 0, 0, 0, 0, 0} , nd4j::DataType::INT32);
|
||||
|
||||
NDArray exp('c', {8, 1, 1, 1, 3}, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
@ -892,8 +892,8 @@ TEST_F(DeclarableOpsTests13, batch_to_space_nd_1) {
|
|||
|
||||
NDArray x('c', {8, 1, 1, 1, 3}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray blockShape('c', {3}, {2, 2, 2} , nd4j::DataType::INT32); // three spatial dimensions
|
||||
NDArray crop('c', {3, 2}, {0, 0, 0, 0, 0, 0} , nd4j::DataType::INT32);
|
||||
NDArray blockShape('c', {3}, {2., 2, 2} , nd4j::DataType::INT32); // three spatial dimensions
|
||||
NDArray crop('c', {3, 2}, {0., 0, 0, 0, 0, 0} , nd4j::DataType::INT32);
|
||||
|
||||
NDArray exp('c', {1, 2, 2, 2, 3}, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
@ -990,7 +990,7 @@ TEST_F(DeclarableOpsTests13, mergemax_1) {
|
|||
TEST_F(DeclarableOpsTests13, mergemax_2) {
|
||||
|
||||
NDArray x1('c', {1, 3}, {0., 1, 2}, nd4j::DataType::FLOAT32);
|
||||
NDArray x2('c', {1, 1}, {1.}, nd4j::DataType::FLOAT32);
|
||||
NDArray x2('c', {1, 1}, std::vector<double>{1.}, nd4j::DataType::FLOAT32);
|
||||
NDArray out('c', {1, 3}, {-1., -1, -1}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::mergemax op;
|
||||
|
@ -2143,10 +2143,10 @@ TEST_F(DeclarableOpsTests13, batchnorm_test7) {
|
|||
NDArray input2('c', {3,15,15,3}, nd4j::DataType::FLOAT32);
|
||||
input2.permutei({0,3,1,2});
|
||||
|
||||
NDArray mean ('c', {3}, {0, 0, 0}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {3}, {1, 1, 1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {3}, {1, 1, 1}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {3}, {0, 0, 0}, nd4j::DataType::FLOAT32);
|
||||
NDArray mean ('c', {3}, {0., 0, 0}, nd4j::DataType::FLOAT32);
|
||||
NDArray variance('c', {3}, {1., 1, 1}, nd4j::DataType::FLOAT32);
|
||||
NDArray gamma ('c', {3}, {1., 1, 1}, nd4j::DataType::FLOAT32);
|
||||
NDArray beta ('c', {3}, {0., 0, 0}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray out1('c', {3,3,15,15}, nd4j::DataType::FLOAT32);
|
||||
NDArray out2('c', {3,3,15,15}, nd4j::DataType::FLOAT32);
|
||||
|
|
|
@ -858,7 +858,7 @@ TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) {
|
|||
TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) {
|
||||
// rank 1
|
||||
NDArray rgbs('c', { 3 }, { 10, 50, 200 }, nd4j::DataType::INT32);
|
||||
NDArray expected('c', { 1 }, { 55 }, nd4j::DataType::INT32);
|
||||
NDArray expected('c', { 1 }, std::vector<double>{ 55 }, nd4j::DataType::INT32);
|
||||
nd4j::ops::rgb_to_grs op;
|
||||
auto result = op.evaluate({&rgbs}, {}, {});
|
||||
auto output = result->at(0);
|
||||
|
@ -1395,7 +1395,7 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test6) {
|
|||
y.assign(4.0);
|
||||
dLdzC.linspace(0.1, 0.1);
|
||||
|
||||
NDArray dLdxExpXC('c', { 1 }, { 115.2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdxExpXC('c', { 1 }, std::vector<double>{ 115.2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::Pow_bp op;
|
||||
|
|
|
@ -55,11 +55,11 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) {
|
|||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
|
||||
auto values = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
|
||||
auto values = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"});
|
||||
auto shape = NDArrayFactory::create<Nd4jLong>({3, 3});
|
||||
auto ranges = NDArrayFactory::create<Nd4jLong>({0,0, 1,1, 2,2});
|
||||
auto def = NDArrayFactory::string("d");
|
||||
auto exp = NDArrayFactory::string('c', {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"});
|
||||
auto exp = NDArrayFactory::string( {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"});
|
||||
|
||||
|
||||
nd4j::ops::compat_sparse_to_dense op;
|
||||
|
@ -70,11 +70,11 @@ TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) {
|
|||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests17, test_compat_string_split_1) {
|
||||
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"});
|
||||
auto x = NDArrayFactory::string( {2}, {"first string", "second"});
|
||||
auto delimiter = NDArrayFactory::string(" ");
|
||||
|
||||
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
|
||||
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
|
||||
auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"});
|
||||
|
||||
nd4j::ops::compat_string_split op;
|
||||
auto result = op.evaluate({&x, &delimiter});
|
||||
|
|
|
@ -79,7 +79,7 @@ TEST_F(DeclarableOpsTests2, gather_2) {
|
|||
TEST_F(DeclarableOpsTests2, gather_3) {
|
||||
|
||||
NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
|
||||
NDArray indices ('c', {1,1}, {2}, nd4j::DataType::INT32);
|
||||
NDArray indices ('c', {1,1}, std::vector<double>{2}, nd4j::DataType::INT32);
|
||||
NDArray expected('c', {2,1,1,4}, {9,10,11,12,21,22,23,24});
|
||||
|
||||
nd4j::ops::gather op;
|
||||
|
@ -186,7 +186,7 @@ TEST_F(DeclarableOpsTests2, gather_7) {
|
|||
TEST_F(DeclarableOpsTests2, gather_8) {
|
||||
|
||||
NDArray input('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, nd4j::DataType::FLOAT32);
|
||||
NDArray indices('c', {1}, {2}, nd4j::DataType::INT32);
|
||||
NDArray indices('c', {1}, std::vector<double>{2}, nd4j::DataType::INT32);
|
||||
NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::gather op;
|
||||
|
@ -206,7 +206,7 @@ TEST_F(DeclarableOpsTests2, gather_8) {
|
|||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, gather_9) {
|
||||
NDArray x('c', {2, 4, 3, 2}, nd4j::DataType::FLOAT32);
|
||||
NDArray indices('c', {2}, {1, 0}, nd4j::DataType::INT32);
|
||||
NDArray indices('c', {2}, std::vector<double>{1, 0}, nd4j::DataType::INT32);
|
||||
|
||||
nd4j::ops::gather op;
|
||||
auto result = op.evaluate({&x, &indices}, {}, {-2});
|
||||
|
@ -238,7 +238,7 @@ TEST_F(DeclarableOpsTests2, gather_10) {
|
|||
TEST_F(DeclarableOpsTests2, gather_11) {
|
||||
|
||||
NDArray x('c', {2, 2}, {1, 2, 3, 4});
|
||||
NDArray indices('c', {2}, {1, 0}, nd4j::DataType::INT64);
|
||||
NDArray indices('c', {2}, std::vector<double>{1, 0}, nd4j::DataType::INT64);
|
||||
NDArray e('c', {2, 2}, {3, 4, 1, 2});
|
||||
|
||||
nd4j::ops::gather op;
|
||||
|
|
|
@ -243,7 +243,7 @@ TEST_F(DeclarableOpsTests5, Test_SetSeed_1) {
|
|||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests5, scatterMul_test1) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
|
||||
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {1}, std::vector<double>({0LL}), nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10.f, 2.f, 3.f, 4.f});
|
||||
|
||||
|
@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests5, scatterMul_test1) {
|
|||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
|
||||
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {1}, std::vector<double>({0LL}), nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f});
|
||||
|
||||
|
@ -279,7 +279,7 @@ TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
|
|||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests5, scatterSub_test1) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
|
||||
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {1}, std::vector<double>({0LL}), nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f});
|
||||
|
||||
|
|
|
@ -1411,7 +1411,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
|
|||
TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0});
|
||||
NDArray exp('c', {1}, {-54.0});
|
||||
NDArray exp('c', {1}, std::vector<double>{-54.0});
|
||||
|
||||
nd4j::ops::matrix_determinant op;
|
||||
auto result = op.evaluate({&x}, {}, {});
|
||||
|
@ -1453,7 +1453,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
|
|||
TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {1, 4, 4});
|
||||
NDArray exp('c', {1}, {-16.0});
|
||||
NDArray exp('c', {1}, std::vector<double>{-16.0});
|
||||
x.linspace(1);
|
||||
x.p(5, 4.0);
|
||||
x.p(12, 12.0);
|
||||
|
|
|
@ -83,7 +83,7 @@ TEST_F(FlatUtilsTests, flat_bool_serde_1) {
|
|||
}
|
||||
|
||||
TEST_F(FlatUtilsTests, flat_string_serde_1) {
|
||||
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
|
||||
auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"});
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||
|
|
|
@ -1277,14 +1277,14 @@ TEST_F(JavaInteropTests, test_size_dtype_1) {
|
|||
}
|
||||
|
||||
TEST_F(JavaInteropTests, test_expandable_array_op_1) {
|
||||
auto x = NDArrayFactory::string('c', {2}, {"first string", "second"});
|
||||
auto d = NDArrayFactory::string(" ");
|
||||
auto x = NDArrayFactory::string( {2}, {"first string", "second"});
|
||||
auto d = NDArrayFactory::string(" ", nd4j::DataType::UTF8);
|
||||
|
||||
auto z0 = NDArrayFactory::create<Nd4jLong>('c', {6});
|
||||
auto z1 = NDArrayFactory::string('c', {3}, {"", "", ""});
|
||||
auto z1 = NDArrayFactory::string( {3}, {"", "", ""});
|
||||
|
||||
auto exp0 = NDArrayFactory::create<Nd4jLong>({0,0, 0,1, 1,0});
|
||||
auto exp1 = NDArrayFactory::string('c', {3}, {"first", "string", "second"});
|
||||
auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"});
|
||||
|
||||
InteropDataBuffer iz0(z0.dataBuffer());
|
||||
InteropDataBuffer iz1(z1.dataBuffer());
|
||||
|
|
|
@ -204,7 +204,7 @@ TEST_F(MultiDataTypeTests, ndarray_repeat_test1) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(MultiDataTypeTests, ndarray_bufferAsT_test1) {
|
||||
NDArray x('f', {2}, {1.5, 3.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', {}, {1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', {}, std::vector<double>{1.5}, nd4j::DataType::FLOAT32);
|
||||
|
||||
const int* buffX = x.bufferAsT<int>();
|
||||
const int* buffY = y.bufferAsT<int>();
|
||||
|
@ -217,8 +217,8 @@ TEST_F(MultiDataTypeTests, ndarray_assign_test1) {
|
|||
NDArray x('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::UINT8);
|
||||
NDArray exp('c', {2,2}, {10, 10, 20, 20}, nd4j::DataType::UINT8);
|
||||
|
||||
NDArray scalar1('c', {}, {10.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray scalar2('c', {}, {20.8}, nd4j::DataType::DOUBLE);
|
||||
NDArray scalar1('c', {}, std::vector<double>{10.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray scalar2('c', {}, std::vector<double>{20.8}, nd4j::DataType::DOUBLE);
|
||||
|
||||
x(0,{0}).assign(scalar1);
|
||||
x(1,{0}).assign(scalar2);
|
||||
|
@ -233,9 +233,9 @@ TEST_F(MultiDataTypeTests, ndarray_assign_test1) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) {
|
||||
NDArray x('f', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::HALF);
|
||||
NDArray exp1('c', {}, {3}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {1,1}, {1}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {2}, {1,2}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, std::vector<double>{3}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {1,1}, std::vector<double>{1}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {2}, std::vector<double>{1,2}, nd4j::DataType::INT64);
|
||||
|
||||
auto scalar1 = x.reduceAlongDimension(nd4j::reduce::CountNonZero, {}/*whole range*/);
|
||||
ASSERT_EQ(scalar1, exp1);
|
||||
|
@ -250,7 +250,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) {
|
||||
NDArray x('c', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT32);
|
||||
NDArray exp1('c', {}, {1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp1('c', {}, std::vector<double>{1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {2}, {0.5,2.5}, nd4j::DataType::FLOAT32);
|
||||
|
||||
auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Mean, {}/*whole range*/);
|
||||
|
@ -265,7 +265,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) {
|
||||
NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::HALF);
|
||||
NDArray exp1('c', {}, {8.}, nd4j::DataType::HALF);
|
||||
NDArray exp1('c', {}, std::vector<double>{8.}, nd4j::DataType::HALF);
|
||||
NDArray exp2('c', {2}, {2.,6.}, nd4j::DataType::HALF);
|
||||
|
||||
auto scalar1 = x.reduceAlongDimension(nd4j::reduce::Sum, {}/*whole range*/);
|
||||
|
@ -278,8 +278,8 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) {
|
||||
NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, nd4j::DataType::HALF);
|
||||
NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL);
|
||||
NDArray exp2('c', {2}, {1,0}, nd4j::DataType::BOOL);
|
||||
NDArray exp1('c', {}, std::vector<double>{1}, nd4j::DataType::BOOL);
|
||||
NDArray exp2('c', {2}, std::vector<double>{1, 0}, nd4j::DataType::BOOL);
|
||||
|
||||
auto scalar1 = x.reduceAlongDimension(nd4j::reduce::IsPositive, {}/*whole range*/);
|
||||
ASSERT_EQ(scalar1, exp1);
|
||||
|
@ -291,8 +291,8 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(MultiDataTypeTests, ndarray_varianceNumber_test1) {
|
||||
NDArray x('f', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, {1.666666667}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, {1.118033989}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp1('c', {}, std::vector<double>{1.666666667}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, std::vector<double>{1.118033989}, nd4j::DataType::FLOAT32);
|
||||
|
||||
auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance);
|
||||
ASSERT_EQ(scalar1, exp1);
|
||||
|
@ -475,8 +475,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test1) {
|
|||
if (!Environment::getInstance()->isExperimentalBuild())
|
||||
return;
|
||||
|
||||
NDArray scalar1('c', {0}, {4}, nd4j::DataType::INT32);
|
||||
NDArray scalar2('c', {0}, {1.5}, nd4j::DataType::HALF);
|
||||
NDArray scalar1('c', {0}, std::vector<double>{4}, nd4j::DataType::INT32);
|
||||
NDArray scalar2('c', {0}, std::vector<double>{1.5}, nd4j::DataType::HALF);
|
||||
|
||||
NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64);
|
||||
|
@ -485,8 +485,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test1) {
|
|||
NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF);
|
||||
NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray exp1('c', {0}, {5}, nd4j::DataType::INT32);
|
||||
NDArray exp2('c', {0}, {6.5}, nd4j::DataType::HALF);
|
||||
NDArray exp1('c', {0}, std::vector<double>{5}, nd4j::DataType::INT32);
|
||||
NDArray exp2('c', {0}, std::vector<double>{6.5}, nd4j::DataType::HALF);
|
||||
NDArray exp3('c', {3,2}, {11, 22, 33, 44, 55, 66}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {2,3}, {12.5, 24.5, 36.5, 48.5, 60.5, 72.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp5('c', {2,2}, {0.4, 1.5, 2.4, 3.5}, nd4j::DataType::HALF);
|
||||
|
@ -553,8 +553,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test1) {
|
|||
if (!Environment::getInstance()->isExperimentalBuild())
|
||||
return;
|
||||
|
||||
NDArray scalar1('c', {0}, {4}, nd4j::DataType::INT32);
|
||||
NDArray scalar2('c', {0}, {1.5}, nd4j::DataType::HALF);
|
||||
NDArray scalar1('c', {0}, std::vector<double>{4}, nd4j::DataType::INT32);
|
||||
NDArray scalar2('c', {0}, std::vector<double>{1.5}, nd4j::DataType::HALF);
|
||||
|
||||
NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64);
|
||||
|
@ -563,8 +563,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test1) {
|
|||
NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF);
|
||||
NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray exp1('c', {0}, {2}, nd4j::DataType::INT32);
|
||||
NDArray exp2('c', {0}, {-0.5}, nd4j::DataType::HALF);
|
||||
NDArray exp1('c', {0}, std::vector<double>{2}, nd4j::DataType::INT32);
|
||||
NDArray exp2('c', {0}, std::vector<double>{-0.5}, nd4j::DataType::HALF);
|
||||
NDArray exp3('c', {3,2}, {8, 17, 26, 35, 44, 53}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {2,3}, {-6.5, -14.5, -22.5, -30.5, -38.5, -46.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp5('c', {2,2}, {0.4, -0.5, -1.6, -2.5}, nd4j::DataType::HALF);
|
||||
|
@ -631,8 +631,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test1) {
|
|||
if (!Environment::getInstance()->isExperimentalBuild())
|
||||
return;
|
||||
|
||||
NDArray scalar1('c', {0}, {3}, nd4j::DataType::INT32);
|
||||
NDArray scalar2('c', {0}, {2.5}, nd4j::DataType::HALF);
|
||||
NDArray scalar1('c', {0}, std::vector<double>{3}, nd4j::DataType::INT32);
|
||||
NDArray scalar2('c', {0}, std::vector<double>{2.5}, nd4j::DataType::HALF);
|
||||
|
||||
NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x2('c', {3,2}, {1, 2, 3, 4, 5, 6}, nd4j::DataType::INT64);
|
||||
|
@ -641,8 +641,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test1) {
|
|||
NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::HALF);
|
||||
NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray exp1('c', {0}, {7}, nd4j::DataType::INT32);
|
||||
NDArray exp2('c', {0}, {17.5}, nd4j::DataType::HALF);
|
||||
NDArray exp1('c', {0}, std::vector<double>{7}, nd4j::DataType::INT32);
|
||||
NDArray exp2('c', {0}, std::vector<double>{17.5}, nd4j::DataType::HALF);
|
||||
NDArray exp3('c', {3,2}, {1, 5, 10, 18, 27, 39}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {2,3}, {1.5, 12.5, 35, 81, 148.5, 253.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp5('c', {2,2}, {0., 0.5, 0.8, 1.5}, nd4j::DataType::HALF);
|
||||
|
@ -709,8 +709,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test1) {
|
|||
if (!Environment::getInstance()->isExperimentalBuild())
|
||||
return;
|
||||
|
||||
NDArray scalar1('c', {0}, {3}, nd4j::DataType::INT32);
|
||||
NDArray scalar2('c', {0}, {2.5}, nd4j::DataType::HALF);
|
||||
NDArray scalar1('c', {0}, std::vector<double>{3}, nd4j::DataType::INT32);
|
||||
NDArray scalar2('c', {0}, std::vector<double>{2.5}, nd4j::DataType::HALF);
|
||||
|
||||
NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, nd4j::DataType::INT64);
|
||||
|
@ -719,8 +719,8 @@ TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test1) {
|
|||
NDArray x5('c', {2,2}, {1, 2, 3, 4}, nd4j::DataType::HALF);
|
||||
NDArray x6('c', {2}, {0.4, 0.5}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray exp1('c', {0}, {1}, nd4j::DataType::INT32);
|
||||
NDArray exp2('c', {0}, {2.5}, nd4j::DataType::HALF);
|
||||
NDArray exp1('c', {0}, std::vector<double>{1}, nd4j::DataType::INT32);
|
||||
NDArray exp2('c', {0}, std::vector<double>{2.5}, nd4j::DataType::HALF);
|
||||
NDArray exp3('c', {3,2}, {6, 8, 8, 8, 9, 9}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {2,3}, {0.25, 0.3125, 0.4375, 0.5625, 0.611111111, 0.722222222}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp5('c', {2,2}, {0.4, 0.25, 0.1333333, 0.125}, nd4j::DataType::HALF);
|
||||
|
@ -792,10 +792,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberFloat_test1) {
|
|||
NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray exp1('c', {0}, {1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {0}, {2}, nd4j::DataType::HALF);
|
||||
NDArray exp3('c', {0}, {2}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp4('c', {0}, {0.25},nd4j::DataType::FLOAT32);
|
||||
NDArray exp1('c', {0}, std::vector<double>{1.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {0}, std::vector<double>{2}, nd4j::DataType::HALF);
|
||||
NDArray exp3('c', {0}, std::vector<double>{2}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp4('c', {0}, std::vector<double>{0.25},nd4j::DataType::FLOAT32);
|
||||
|
||||
|
||||
NDArray scalar = x1.reduceNumber(reduce::Mean);
|
||||
|
@ -829,10 +829,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberSame_test1) {
|
|||
NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray exp1('c', {0}, {6}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {0}, {8}, nd4j::DataType::HALF);
|
||||
NDArray exp3('c', {0}, {8}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp4('c', {0}, {1}, nd4j::DataType::BOOL);
|
||||
NDArray exp1('c', {0}, std::vector<double>{6}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {0}, std::vector<double>{8}, nd4j::DataType::HALF);
|
||||
NDArray exp3('c', {0}, std::vector<double>{8}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp4('c', {0}, std::vector<double>{1}, nd4j::DataType::BOOL);
|
||||
|
||||
|
||||
NDArray scalar = x1.reduceNumber(reduce::Sum);
|
||||
|
@ -866,7 +866,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberBool_test1) {
|
|||
NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x4('c', {2,2}, {-2, -1, 0, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray exp1('c', {0}, {1}, nd4j::DataType::BOOL);
|
||||
NDArray exp1('c', {0}, std::vector<double>{1}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray scalar = x1.reduceNumber(reduce::IsFinite);
|
||||
ASSERT_EQ(scalar, exp1);
|
||||
|
@ -899,10 +899,10 @@ TEST_F(MultiDataTypeTests, ndarray_reduceNumberLong_test1) {
|
|||
NDArray x3('c', {2,2}, {0.5, -1.5, 0, 3.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x4('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray exp1('c', {0}, {3}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {0}, {4}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {0}, {3}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {0}, {2}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {0}, std::vector<double>{3}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {0}, std::vector<double>{4}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {0}, std::vector<double>{3}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {0}, std::vector<double>{2}, nd4j::DataType::INT64);
|
||||
|
||||
NDArray scalar = x1.reduceNumber(reduce::CountNonZero);
|
||||
ASSERT_EQ(scalar, exp1);
|
||||
|
@ -934,9 +934,9 @@ TEST_F(MultiDataTypeTests, ndarray_indexReduceNumber_test1) {
|
|||
NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, nd4j::DataType::HALF);
|
||||
NDArray x3('c', {2,2}, {0, -1, 0, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray exp1('c', {0}, {3}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {0}, {2}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {0}, {1}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {0}, std::vector<double>{3}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {0}, std::vector<double>{2}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {0}, std::vector<double>{1}, nd4j::DataType::INT64);
|
||||
|
||||
NDArray scalar = x1.indexReduceNumber(nd4j::indexreduce::IndexAbsoluteMax);
|
||||
ASSERT_EQ(scalar, exp1);
|
||||
|
@ -1238,15 +1238,15 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) {
|
|||
NDArray x7('c', {2}, {1, 2}, nd4j::DataType::INT64);
|
||||
NDArray x8('c', {2,2}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray x13('c', {0}, {3}, nd4j::DataType::INT64);
|
||||
NDArray x14('c', {0}, {1.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x13('c', {0}, std::vector<double>{3}, nd4j::DataType::INT64);
|
||||
NDArray x14('c', {0}, std::vector<double>{1.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x15(nd4j::DataType::DOUBLE);
|
||||
NDArray x16('c', {2,2}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray exp1('c', {2,2}, {11, 22, 31, 42}, nd4j::DataType::HALF);
|
||||
NDArray exp2('c', {2,2}, {11, 22, 31, 42}, nd4j::DataType::INT32);
|
||||
NDArray exp3('c', {2,2}, {1, 1, 1, 1}, nd4j::DataType::BOOL);
|
||||
NDArray exp4('c', {0}, {4.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp4('c', {0}, std::vector<double>{4.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp5('c', {2,2}, {11.5, 21.5, 31.5, 41.5}, nd4j::DataType::DOUBLE);
|
||||
|
||||
x1.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), x2, x3);
|
||||
|
@ -1289,13 +1289,13 @@ TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test2) {
|
|||
NDArray x1('c', {2,2}, {10, 20, 30, 40}, nd4j::DataType::HALF);
|
||||
NDArray x2('c', {2}, {10, 40}, nd4j::DataType::HALF);
|
||||
NDArray x3('c', {2,2}, nd4j::DataType::BOOL);
|
||||
NDArray x4('c', {0}, {10}, nd4j::DataType::HALF);
|
||||
NDArray x5('c', {0}, {20}, nd4j::DataType::HALF);
|
||||
NDArray x4('c', {0}, std::vector<double>{10}, nd4j::DataType::HALF);
|
||||
NDArray x5('c', {0}, std::vector<double>{20}, nd4j::DataType::HALF);
|
||||
NDArray x6(nd4j::DataType::BOOL);
|
||||
|
||||
NDArray exp1('c', {2,2}, {1, 0, 0, 1}, nd4j::DataType::BOOL);
|
||||
NDArray exp2('c', {2,2}, {1, 0, 0, 0}, nd4j::DataType::BOOL);
|
||||
NDArray exp3('c', {0}, {0}, nd4j::DataType::BOOL);
|
||||
NDArray exp3('c', {0}, std::vector<double>{0}, nd4j::DataType::BOOL);
|
||||
|
||||
x1.applyTrueBroadcast(BroadcastBoolOpsTuple(nd4j::scalar::EqualTo, nd4j::pairwise::EqualTo, nd4j::broadcast::EqualTo), x2, x3);
|
||||
ASSERT_EQ(x3, exp1);
|
||||
|
@ -1459,16 +1459,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedLambda_test1) {
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) {
|
||||
|
||||
NDArray x1('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::DOUBLE);
|
||||
NDArray x2('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT64);
|
||||
NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x1('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::DOUBLE);
|
||||
NDArray x2('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT64);
|
||||
NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE);
|
||||
NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, nd4j::DataType::BOOL);
|
||||
NDArray x7('c', {2,2}, nd4j::DataType::BOOL);
|
||||
NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::DOUBLE);
|
||||
NDArray other3('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::INT64);
|
||||
NDArray other3('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::INT64);
|
||||
NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, nd4j::DataType::BOOL);
|
||||
|
||||
auto func1 = [](float elem1, float elem2) { return elem1 + elem2; };
|
||||
|
@ -1478,10 +1478,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) {
|
|||
auto func5 = [](float elem1, int elem2) { return elem1 - elem2; };
|
||||
|
||||
NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp2('c', {2,2}, {0, 0, 0, 0}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {2,2}, {0., 0, 0, 0}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp5('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
NDArray exp5('c', {2,2}, {0., 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
x1.applyPairwiseLambda<double>(other2, func1, x4);
|
||||
ASSERT_EQ(x4, exp1);
|
||||
|
@ -1505,16 +1505,16 @@ TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) {
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) {
|
||||
|
||||
NDArray x1('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::DOUBLE);
|
||||
NDArray x2('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT64);
|
||||
NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x1('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::DOUBLE);
|
||||
NDArray x2('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT64);
|
||||
NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE);
|
||||
NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::FLOAT32);
|
||||
NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, nd4j::DataType::BOOL);
|
||||
NDArray x7('c', {2,2}, nd4j::DataType::BOOL);
|
||||
NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, nd4j::DataType::DOUBLE);
|
||||
NDArray other3('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::INT64);
|
||||
NDArray other3('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::INT64);
|
||||
NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, nd4j::DataType::BOOL);
|
||||
|
||||
auto func1 = [](Nd4jLong idx, float elem1, float elem2) { return elem1 + elem2 + idx; };
|
||||
|
@ -1524,10 +1524,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) {
|
|||
auto func5 = [](Nd4jLong idx, float elem1, int elem2) { return elem1 - elem2 + idx; };
|
||||
|
||||
NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp2('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp5('c', {2,2}, {0, 1, 1, 1}, nd4j::DataType::BOOL);
|
||||
NDArray exp5('c', {2,2}, {0., 1, 1, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
x1.applyIndexedPairwiseLambda<double>(other2, func1, x4);
|
||||
ASSERT_EQ(x4, exp1);
|
||||
|
@ -1551,25 +1551,25 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) {
|
|||
//////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) {
|
||||
|
||||
NDArray x1('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::DOUBLE);
|
||||
NDArray x2('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::DOUBLE);
|
||||
NDArray x1('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::DOUBLE);
|
||||
NDArray x2('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::DOUBLE);
|
||||
NDArray x3('c', {2,2}, {0, -1.5, -2.5, -3.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x4('c', {2,2}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray x5('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::INT32);
|
||||
NDArray x6('c', {2,2}, {0, -1, -2, -3}, nd4j::DataType::INT32);
|
||||
NDArray x7('c', {2,2}, {0, 10, 20, 30}, nd4j::DataType::INT32);
|
||||
NDArray x5('c', {2,2}, {0., 1, 2, 3}, nd4j::DataType::INT32);
|
||||
NDArray x6('c', {2,2}, {0., -1, -2, -3}, nd4j::DataType::INT32);
|
||||
NDArray x7('c', {2,2}, {0., 10, 20, 30}, nd4j::DataType::INT32);
|
||||
|
||||
NDArray x8('c', {2,2}, {0, 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
NDArray x9('c', {2,2}, {1, 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
NDArray x10('c', {2,2}, {0, 0, 0, 0}, nd4j::DataType::BOOL);
|
||||
NDArray x8('c', {2,2}, {0., 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
NDArray x9('c', {2,2}, {1., 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
NDArray x10('c', {2,2}, {0., 0, 0, 0}, nd4j::DataType::BOOL);
|
||||
|
||||
auto func1 = [](double elem1, float elem2, int elem3) { return elem1 + elem2 + elem3; };
|
||||
auto func2 = [](float elem1, float elem2, float elem3) { return elem1 + elem2 + elem3; };
|
||||
auto func3 = [](int elem1, int elem2, int elem3) { return elem1 + elem2 + elem3; };
|
||||
auto func4 = [](bool elem1, bool elem2, bool elem3) { return elem1 + elem2 + elem3; };
|
||||
|
||||
NDArray exp('c', {2,2}, {1, 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
NDArray exp('c', {2,2}, {1., 1, 0, 1}, nd4j::DataType::BOOL);
|
||||
|
||||
x1.applyTriplewiseLambda<double>(x2, x3, func1, x4);
|
||||
ASSERT_EQ(x4, x2);
|
||||
|
@ -1590,7 +1590,7 @@ TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) {
|
|||
TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) {
|
||||
|
||||
NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp1('c', {}, {5}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, std::vector<double>{5}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64);
|
||||
|
||||
|
@ -1608,10 +1608,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) {
|
|||
TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) {
|
||||
|
||||
NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE);
|
||||
NDArray scalar('c', {}, {5}, nd4j::DataType::INT64);
|
||||
NDArray scalar('c', {}, std::vector<double>{5}, nd4j::DataType::INT64);
|
||||
NDArray vec1('c', {2}, {2,2}, nd4j::DataType::INT64);
|
||||
NDArray vec2('c', {3}, {1,1,1}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, {5}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, std::vector<double>{5}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64);
|
||||
|
||||
|
@ -1632,8 +1632,8 @@ TEST_F(MultiDataTypeTests, applyReduce3_test1) {
|
|||
NDArray x2('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32);
|
||||
NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp1('c', {}, std::vector<double>{-30}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, std::vector<double>{15}, nd4j::DataType::DOUBLE);
|
||||
|
||||
auto result = x1.applyReduce3(reduce3::Dot, x2);
|
||||
ASSERT_EQ(result, exp1);
|
||||
|
@ -1654,8 +1654,8 @@ TEST_F(MultiDataTypeTests, applyReduce3_test2) {
|
|||
NDArray x7('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x8('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp1('c', {}, std::vector<double>{-30}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, std::vector<double>{15}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp3('c', {3}, {-18,-20,-18}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp4('c', {2}, {-28,-28}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp5('c', {3}, {7.5,10.5,13.5}, nd4j::DataType::DOUBLE);
|
||||
|
|
|
@ -184,7 +184,7 @@ TEST_F(NDArrayConstructorsTests, test_linspace_1) {
|
|||
TEST_F(NDArrayConstructorsTests, test_constructor_10) {
|
||||
|
||||
NDArray scalar1(nd4j::DataType::DOUBLE); // scalar1 = 0
|
||||
NDArray scalar2('c', {}, {0});
|
||||
NDArray scalar2('c', {}, std::vector<double>{0});
|
||||
|
||||
ASSERT_TRUE(scalar1.isActualOnDeviceSide());
|
||||
ASSERT_TRUE(!scalar1.isActualOnHostSide());
|
||||
|
|
|
@ -1226,8 +1226,8 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_3) {
|
|||
NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray exp1('c', {}, {-204}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, {31.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp1('c', {}, std::vector<double>{-204}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {}, std::vector<double>{31.5}, nd4j::DataType::DOUBLE);
|
||||
|
||||
|
||||
auto z = x1.applyReduce3(reduce3::Dot, x2);
|
||||
|
@ -1260,7 +1260,7 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) {
|
|||
NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f,
|
||||
-4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f},
|
||||
nd4j::DataType::FLOAT32);
|
||||
NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp3('c', {1,1}, std::vector<double>{31.5}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE);
|
||||
|
||||
auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2});
|
||||
|
@ -1292,15 +1292,15 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test1) {
|
|||
|
||||
NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray scalar('c', {}, {100}, nd4j::DataType::INT64);
|
||||
NDArray scalar('c', {}, std::vector<double>{100}, nd4j::DataType::INT64);
|
||||
NDArray vec1('c', {2}, {100,100}, nd4j::DataType::INT64);
|
||||
NDArray vec2('c', {3}, {100,100,100}, nd4j::DataType::INT64);
|
||||
|
||||
NDArray exp1('c', {}, {1}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64);
|
||||
|
||||
NDArray exp4('c', {}, {2}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {}, std::vector<double>{2}, nd4j::DataType::INT64);
|
||||
NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64);
|
||||
NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64);
|
||||
|
||||
|
@ -1331,11 +1331,11 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test2) {
|
|||
|
||||
NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray exp1('c', {}, {1}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, std::vector<double>{1}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64);
|
||||
|
||||
NDArray exp4('c', {}, {2}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {}, std::vector<double>{2}, nd4j::DataType::INT64);
|
||||
NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64);
|
||||
NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64);
|
||||
|
||||
|
@ -1365,13 +1365,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) {
|
|||
|
||||
NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::INT32);
|
||||
|
||||
NDArray z1('c', {}, {100}, nd4j::DataType::DOUBLE);
|
||||
NDArray z1('c', {}, std::vector<double>{100}, nd4j::DataType::DOUBLE);
|
||||
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
|
||||
NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::DOUBLE);
|
||||
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
|
||||
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp1('c', {}, std::vector<double>{2.166667}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32);
|
||||
|
@ -1403,7 +1403,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test2) {
|
|||
|
||||
NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp1('c', {}, std::vector<double>{2.166667}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::DOUBLE);
|
||||
|
@ -1477,13 +1477,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
|
|||
|
||||
NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32);
|
||||
NDArray z1('c', {}, std::vector<double>{100}, nd4j::DataType::FLOAT32);
|
||||
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
|
||||
NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::FLOAT32);
|
||||
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
|
||||
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray exp1('c', {}, {26.5f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp1('c', {}, std::vector<double>{26.5f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp3('c', {3}, {19.f,4.f,3.5f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32);
|
||||
|
@ -1515,7 +1515,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test2) {
|
|||
|
||||
NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::INT64);
|
||||
|
||||
NDArray exp1('c', {}, {26}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, std::vector<double>{26}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {2,2}, {9,12,3,2}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {3}, {18,4,4}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, nd4j::DataType::INT64);
|
||||
|
@ -1547,13 +1547,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) {
|
|||
|
||||
NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray z1('c', {}, {true}, nd4j::DataType::BOOL);
|
||||
NDArray z1('c', {}, std::vector<double>{true}, nd4j::DataType::BOOL);
|
||||
NDArray z2('c', {2,2}, {true,true,true,true}, nd4j::DataType::BOOL);
|
||||
NDArray z3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
|
||||
NDArray z4('c', {3,2}, {true,true,true,true,true,true}, nd4j::DataType::BOOL);
|
||||
NDArray z5('c', {2}, {true,true}, nd4j::DataType::BOOL);
|
||||
|
||||
NDArray exp1('c', {}, {true}, nd4j::DataType::BOOL);
|
||||
NDArray exp1('c', {}, std::vector<double>{true}, nd4j::DataType::BOOL);
|
||||
NDArray exp2('c', {2,2}, {true,true,false,true}, nd4j::DataType::BOOL);
|
||||
NDArray exp3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
|
||||
NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL);
|
||||
|
@ -1585,7 +1585,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) {
|
|||
|
||||
NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::INT32);
|
||||
|
||||
NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL);
|
||||
NDArray exp1('c', {}, std::vector<double>{1}, nd4j::DataType::BOOL);
|
||||
NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL);
|
||||
NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL);
|
||||
NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, nd4j::DataType::BOOL);
|
||||
|
@ -1617,13 +1617,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) {
|
|||
|
||||
NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray z1('c', {}, {100}, nd4j::DataType::INT64);
|
||||
NDArray z1('c', {}, std::vector<double>{100}, nd4j::DataType::INT64);
|
||||
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64);
|
||||
NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::INT64);
|
||||
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::INT64);
|
||||
NDArray z5('c', {2}, {100,100}, nd4j::DataType::INT64);
|
||||
|
||||
NDArray exp1('c', {}, {2}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, std::vector<double>{2}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {2,2}, {0,1,0,1}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {3}, {1,1,0}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {3,2}, {0,1,0,1,0,0}, nd4j::DataType::INT64);
|
||||
|
@ -1655,7 +1655,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test2) {
|
|||
|
||||
NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::INT32);
|
||||
|
||||
NDArray exp1('c', {}, {4}, nd4j::DataType::INT64);
|
||||
NDArray exp1('c', {}, std::vector<double>{4}, nd4j::DataType::INT64);
|
||||
NDArray exp2('c', {2,2}, {1,1,0,2}, nd4j::DataType::INT64);
|
||||
NDArray exp3('c', {3}, {2,2,0}, nd4j::DataType::INT64);
|
||||
NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, nd4j::DataType::INT64);
|
||||
|
|
|
@ -692,7 +692,7 @@ TEST_F(ParityOpsTests, Test_Bias_Add_1) {
|
|||
|
||||
TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||
NDArray idc('c', {1}, {0}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {1, 1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {2, 3, 3, 4});
|
||||
|
||||
|
@ -710,7 +710,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
|
|||
TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
|
||||
|
||||
auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4});
|
||||
NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {1, 4}, {0., 1, 2, 3}, nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {1, 4}, {2, 3, 4, 5});
|
||||
|
||||
|
@ -727,7 +727,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
|
|||
|
||||
TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||
NDArray idc('c', {1}, {0}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2}, {1, 1, 1, 1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8});
|
||||
|
||||
|
@ -744,7 +744,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_3) {
|
|||
|
||||
TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||
NDArray idc('c', {1, 2}, {0, 0}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {1, 2}, std::vector<double>{0, 0}, nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8});
|
||||
|
||||
|
@ -761,7 +761,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_4) {
|
|||
|
||||
TEST_F(ParityOpsTests, Test_Scatter_Add_5) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||
NDArray idc('c', {2, 2}, {1, 1, 0, 0}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {2, 2}, {1., 1, 0, 0}, nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.});
|
||||
|
||||
|
@ -796,7 +796,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) {
|
|||
|
||||
TEST_F(ParityOpsTests, Test_Scatter_Add_7) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {10, 3}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f,25.f,26.f,27.f,28.f,29.f,30.f});
|
||||
NDArray idc('c', {}, {5}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {}, std::vector<double>{5}, nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {3}, {10.f, 20.f, 30.f});
|
||||
auto exp = NDArrayFactory::create<float>('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f});
|
||||
|
||||
|
@ -845,7 +845,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_9) {
|
|||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ParityOpsTests, scatterMax_test1) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||
NDArray idc('c', {1}, {0.}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {1}, std::vector<double>{0.}, nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10, 2, 3, 4});
|
||||
|
||||
|
@ -879,7 +879,7 @@ TEST_F(ParityOpsTests, scatterMax_test2) {
|
|||
|
||||
TEST_F(ParityOpsTests, scatterMax_test3) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||
NDArray idc('c', {1}, {0}, nd4j::DataType::INT64);
|
||||
NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2}, {10, 1, 30, 1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8});
|
||||
|
||||
|
@ -896,7 +896,7 @@ TEST_F(ParityOpsTests, scatterMax_test3) {
|
|||
|
||||
TEST_F(ParityOpsTests, scatterMax_test4) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||
NDArray idc('c', {1,2}, {0,0}, nd4j::DataType::INT32);
|
||||
NDArray idc('c', {1,2}, std::vector<double>{0.,0}, nd4j::DataType::INT32);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8});
|
||||
|
||||
|
@ -948,7 +948,7 @@ TEST_F(ParityOpsTests, scatterMax_test6) {
|
|||
|
||||
TEST_F(ParityOpsTests, scatterMin_test1) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||
NDArray idc('c', {1}, {0}, nd4j::DataType::INT32);
|
||||
NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT32);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {-1, 1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-1, 1, 3, 4});
|
||||
|
||||
|
@ -982,7 +982,7 @@ TEST_F(ParityOpsTests, scatterMin_test2) {
|
|||
|
||||
TEST_F(ParityOpsTests, scatterMin_test3) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||
NDArray idc('c', {1}, {0}, nd4j::DataType::INT32);
|
||||
NDArray idc('c', {1}, std::vector<double>({0}), nd4j::DataType::INT32);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2}, {10, 1, 30, 2});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8});
|
||||
|
||||
|
@ -999,7 +999,7 @@ TEST_F(ParityOpsTests, scatterMin_test3) {
|
|||
|
||||
TEST_F(ParityOpsTests, scatterMin_test4) {
|
||||
auto matrix = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8});
|
||||
NDArray idc('c', {1,2}, {0,0}, nd4j::DataType::INT32);
|
||||
NDArray idc('c', {1,2}, std::vector<double>{0.,0}, nd4j::DataType::INT32);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8});
|
||||
|
||||
|
|
|
@ -1005,24 +1005,24 @@ TEST_F(RNGTests, test_uniform_119) {
|
|||
}
|
||||
|
||||
TEST_F(RNGTests, test_multinomial_1) {
|
||||
|
||||
|
||||
NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('f', { 3, 3 }, { 0, 1, 2, 2, 0, 0, 1, 2, 1 }, nd4j::DataType::INT64);
|
||||
NDArray expected('f', { 3, 3 }, { 0., 1, 2, 2, 0, 0, 1, 2, 1 }, nd4j::DataType::INT64);
|
||||
NDArray output('f', { 3, 3 }, nd4j::DataType::INT64);
|
||||
NDArray samples('f', { 1 }, { 3 }, nd4j::DataType::INT32);
|
||||
|
||||
NDArray samples('f', { 1 }, std::vector<double>({3}), nd4j::DataType::INT32);
|
||||
|
||||
nd4j::ops::random_multinomial op;
|
||||
RandomGenerator rng(1234, 1234);
|
||||
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, {}, false) );
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
|
||||
NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64);
|
||||
NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64);
|
||||
|
||||
auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 });
|
||||
auto outputZ = result->at(0);
|
||||
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(expectedZ.isSameShape(outputZ));
|
||||
ASSERT_TRUE(expectedZ.equalsTo(outputZ));
|
||||
|
@ -1031,7 +1031,7 @@ TEST_F(RNGTests, test_multinomial_1) {
|
|||
|
||||
TEST_F(RNGTests, test_multinomial_2) {
|
||||
|
||||
NDArray samples('c', { 1 }, { 20 }, nd4j::DataType::INT32);
|
||||
NDArray samples('c', { 1 }, std::vector<double>{ 20 }, nd4j::DataType::INT32);
|
||||
NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, nd4j::DataType::INT64);
|
||||
NDArray output('c', { 3, 20 }, nd4j::DataType::INT64);
|
||||
|
@ -1041,11 +1041,11 @@ TEST_F(RNGTests, test_multinomial_2) {
|
|||
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
|
||||
NDArray probs2('c', { 5, 3 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected2('c', { 20, 3 }, { 0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2 }, nd4j::DataType::INT64);
|
||||
NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64);
|
||||
|
||||
|
||||
rng.setStates(1234, 1234);
|
||||
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, {}, false));
|
||||
ASSERT_TRUE(expected2.isSameShape(output2));
|
||||
|
@ -1053,16 +1053,17 @@ TEST_F(RNGTests, test_multinomial_2) {
|
|||
}
|
||||
|
||||
TEST_F(RNGTests, test_multinomial_3) {
|
||||
|
||||
|
||||
NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('c', { 4, 5 }, nd4j::DataType::INT64);
|
||||
NDArray output('c', { 4, 5 }, nd4j::DataType::INT64);
|
||||
NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32);
|
||||
NDArray samples('c', { 1 }, std::vector<double>{ 5 }, nd4j::DataType::INT32);
|
||||
RandomGenerator rng(1234, 1234);
|
||||
|
||||
nd4j::ops::random_multinomial op;
|
||||
|
||||
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false));
|
||||
|
||||
|
||||
rng.setStates(1234, 1234);
|
||||
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
|
||||
ASSERT_TRUE(expected.isSameShape(output));
|
||||
|
@ -1074,7 +1075,7 @@ TEST_F(RNGTests, test_multinomial_4) {
|
|||
NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray expected('c', { 5, 4 }, nd4j::DataType::INT64);
|
||||
NDArray output('c', { 5, 4 }, nd4j::DataType::INT64);
|
||||
NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32);
|
||||
NDArray samples('c', { 1 }, std::vector<double>{ 5 }, nd4j::DataType::INT32);
|
||||
|
||||
RandomGenerator rng(1234, 1234);
|
||||
nd4j::ops::random_multinomial op;
|
||||
|
@ -1092,15 +1093,15 @@ TEST_F(RNGTests, test_multinomial_5) {
|
|||
int ClassValue = 2;
|
||||
int Samples = 100000;
|
||||
|
||||
NDArray samples('c', { 1 }, { 1.*Samples }, nd4j::DataType::INT32);
|
||||
|
||||
NDArray samples('c', { 1 }, std::vector<double>{ 1.*Samples }, nd4j::DataType::INT32);
|
||||
|
||||
NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
||||
nd4j::ops::random_multinomial op;
|
||||
|
||||
NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64);
|
||||
RandomGenerator rng(1234, 1234);
|
||||
|
||||
|
||||
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false));
|
||||
|
||||
auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
|
||||
|
@ -1109,7 +1110,7 @@ TEST_F(RNGTests, test_multinomial_5) {
|
|||
// theoretical values for binomial
|
||||
ASSERT_NEAR(0.5, deviation.e<double>(0), 4e-3); // 1000000 3e-3);
|
||||
ASSERT_NEAR(0.5, mean.e<double>(0), 4e-3); // 1000000 3e-3);
|
||||
|
||||
|
||||
for (int i = 0; i < output.lengthOf(); i++) {
|
||||
auto value = output.e<Nd4jLong>(i);
|
||||
ASSERT_TRUE(value >= 0 && value < ClassValue);
|
||||
|
@ -1139,8 +1140,8 @@ TEST_F(RNGTests, test_multinomial_6) {
|
|||
int batchValue = 1;
|
||||
int ClassValue = 5;
|
||||
int Samples = 100000;
|
||||
|
||||
NDArray samples('c', { 1 }, { 1. * Samples }, nd4j::DataType::INT32);
|
||||
|
||||
NDArray samples('c', { 1 }, std::vector<double>{ 1. * Samples }, nd4j::DataType::INT32);
|
||||
|
||||
nd4j::ops::random_multinomial op;
|
||||
NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, nd4j::DataType::DOUBLE);
|
||||
|
@ -1152,8 +1153,8 @@ TEST_F(RNGTests, test_multinomial_6) {
|
|||
auto outputR = resultR->at(0);
|
||||
ASSERT_EQ(Status::OK(), resultR->status());
|
||||
|
||||
NDArray countsR('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
for (int i = 0; i < outputR->lengthOf(); i++) {
|
||||
auto value = outputR->e<Nd4jLong>(i);
|
||||
ASSERT_TRUE(value >= 0 && value < ClassValue);
|
||||
|
@ -1179,11 +1180,11 @@ TEST_F(RNGTests, test_multinomial_6) {
|
|||
RandomGenerator rng(1234, 1234);
|
||||
NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
|
||||
NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64);
|
||||
|
||||
|
||||
ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false));
|
||||
|
||||
NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray counts('c', { ClassValue }, { 0., 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
for (int i = 0; i < output.lengthOf(); i++) {
|
||||
auto value = output.e<Nd4jLong>(i);
|
||||
ASSERT_TRUE(value >= 0 && value < ClassValue);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019-2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -16,6 +17,7 @@
|
|||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
|
||||
//
|
||||
|
||||
|
||||
|
@ -30,7 +32,7 @@ class StringTests : public testing::Test {
|
|||
public:
|
||||
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_1) {
|
||||
std::string f("alpha");
|
||||
auto array = NDArrayFactory::string(f);
|
||||
|
@ -43,7 +45,7 @@ TEST_F(StringTests, Basic_Test_1) {
|
|||
|
||||
ASSERT_EQ(f, z);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_2) {
|
||||
std::string f("alpha");
|
||||
auto array = NDArrayFactory::string(f.c_str());
|
||||
|
@ -56,23 +58,213 @@ TEST_F(StringTests, Basic_Test_2) {
|
|||
|
||||
ASSERT_EQ(f, z);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_3) {
|
||||
auto array = NDArrayFactory::string('c', {3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"});
|
||||
|
||||
auto array = NDArrayFactory::string({3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"});
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_4) {
|
||||
|
||||
NDArray array( { 3, 2 }, std::vector<const char32_t*>{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" });
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_5) {
|
||||
|
||||
NDArray array( { 3, 2 }, std::vector<const char16_t*>{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" });
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_6) {
|
||||
|
||||
NDArray array( { 3, 2 }, std::vector<const char*>{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" });
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_7) {
|
||||
|
||||
NDArray array( { 3, 2 }, std::vector<std::u32string>{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" });
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_8) {
|
||||
|
||||
NDArray array( { 3, 2 }, std::vector<std::u16string>{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" });
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_9) {
|
||||
|
||||
NDArray array( { 3, 2 }, std::vector<std::string>{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" });
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_10) {
|
||||
|
||||
NDArray array(std::u32string(U"gamma€한"));
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_11) {
|
||||
|
||||
NDArray array(U"gamma€한");
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_12) {
|
||||
|
||||
NDArray array(std::u16string(u"gamma€한"));
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_13) {
|
||||
|
||||
NDArray array(u"gamma€한");
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_14) {
|
||||
|
||||
NDArray array(std::string("gamma€한"));
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_15) {
|
||||
|
||||
NDArray array("gamma€한");
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_16) {
|
||||
|
||||
auto array = NDArrayFactory::string( { 3, 2 }, std::vector<std::string>{ "alpha", "beta", "gamma", "phi", "theta", "omega" });
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_17) {
|
||||
|
||||
auto array = NDArrayFactory::string({ 3, 2 }, std::vector<const char*>{ "alpha", "beta", "gamma", "phi", "theta", "omega" });
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_18) {
|
||||
|
||||
auto array = NDArrayFactory::string({ 3, 2 }, std::vector<std::u16string>{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" });
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_19) {
|
||||
|
||||
auto array = NDArrayFactory::string( { 3, 2 }, std::vector<const char16_t*>{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" });
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_20) {
|
||||
|
||||
auto array = NDArrayFactory::string( { 3, 2 }, std::vector<std::u32string>{ U"alpha", U"beta", U"gamma", U"phi", U"theta", U"omega" });
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_21) {
|
||||
|
||||
auto array = NDArrayFactory::string( { 3, 2 }, std::vector<const char32_t*>{ U"alpha", U"òèçùà12345¤z", U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї", U"phi", U"theta", U"omega" });
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_22) {
|
||||
std::u16string f(u"ß水𝄋ÿ€한𐍈®кею90ощъ]ї");
|
||||
auto array = NDArrayFactory::string(f.c_str());
|
||||
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto z = array.e<std::u16string>(0);
|
||||
|
||||
ASSERT_EQ(f, z);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_23) {
|
||||
std::u32string f(U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї");
|
||||
auto array = NDArrayFactory::string(f.c_str());
|
||||
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto z = array.e<std::u32string>(0);
|
||||
|
||||
ASSERT_EQ(f, z);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Export_Test_1) {
|
||||
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
|
||||
|
||||
auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"});
|
||||
auto vector = array.asByteVector();
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_dup_1) {
|
||||
std::string f("alpha");
|
||||
auto array = NDArrayFactory::string(f);
|
||||
|
@ -91,20 +283,20 @@ TEST_F(StringTests, Basic_dup_1) {
|
|||
|
||||
delete dup;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, byte_length_test_1) {
|
||||
std::string f("alpha");
|
||||
auto array = NDArrayFactory::string(f);
|
||||
|
||||
ASSERT_EQ(f.length(), StringUtils::byteLength(array));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, byte_length_test_2) {
|
||||
auto array = NDArrayFactory::string('c', {2}, {"alpha", "beta"});
|
||||
auto array = NDArrayFactory::string( {2}, {"alpha", "beta"});
|
||||
|
||||
ASSERT_EQ(9, StringUtils::byteLength(array));
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, test_split_1) {
|
||||
auto split = StringUtils::split("alpha beta gamma", " ");
|
||||
|
||||
|
@ -112,4 +304,562 @@ TEST_F(StringTests, test_split_1) {
|
|||
ASSERT_EQ(std::string("alpha"), split[0]);
|
||||
ASSERT_EQ(std::string("beta"), split[1]);
|
||||
ASSERT_EQ(std::string("gamma"), split[2]);
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, test_unicode_utf8_utf16) {
|
||||
|
||||
std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
|
||||
std::u16string utf16Res;
|
||||
ASSERT_TRUE(StringUtils::u8StringToU16String(utf8, utf16Res));
|
||||
|
||||
ASSERT_EQ(utf16Res.size(), utf16Exp.size());
|
||||
for (auto i = 0; i < utf16Exp.size(); i++) {
|
||||
ASSERT_EQ(utf16Exp[i], utf16Res[i]);
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, test_unicode_utf8_utf32) {
|
||||
|
||||
std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
|
||||
std::u32string utf32Res;
|
||||
ASSERT_TRUE(StringUtils::u8StringToU32String(utf8, utf32Res));
|
||||
|
||||
ASSERT_EQ(utf32Res.size(), utf32Exp.size());
|
||||
for (auto i = 0; i < utf32Exp.size(); i++) {
|
||||
ASSERT_EQ(utf32Exp[i], utf32Res[i]);
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, test_unicode_utf16_utf8) {
|
||||
|
||||
std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
|
||||
std::string utf8Res;
|
||||
ASSERT_TRUE(StringUtils::u16StringToU8String(utf16, utf8Res));
|
||||
|
||||
ASSERT_EQ(utf8Res.size(), utf8Exp.size());
|
||||
for (auto i = 0; i < utf8Exp.size(); i++) {
|
||||
ASSERT_EQ(utf8Exp[i], utf8Res[i]);
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, test_unicode_utf32_utf8) {
|
||||
|
||||
std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~";
|
||||
std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~";
|
||||
|
||||
std::string utf8Res;
|
||||
ASSERT_TRUE(StringUtils::u32StringToU8String(utf32, utf8Res));
|
||||
|
||||
ASSERT_EQ(utf8Res.size(), utf8Exp.size());
|
||||
for (auto i = 0; i < utf8Exp.size(); i++) {
|
||||
ASSERT_EQ(utf8Exp[i], utf8Res[i]);
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, test_unicode_utf16_utf32) {
|
||||
|
||||
std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
|
||||
std::u32string utf32Res;
|
||||
ASSERT_TRUE(StringUtils::u16StringToU32String(utf16, utf32Res));
|
||||
|
||||
ASSERT_EQ(utf32Res.size(), utf32Exp.size());
|
||||
for (auto i = 0; i < utf32Exp.size(); i++) {
|
||||
ASSERT_EQ(utf32Exp[i], utf32Res[i]);
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, test_unicode_utf32_utf16) {
|
||||
|
||||
std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~";
|
||||
|
||||
std::u16string utf16Res;
|
||||
ASSERT_TRUE(StringUtils::u32StringToU16String(utf32, utf16Res));
|
||||
|
||||
ASSERT_EQ(utf16Res.size(), utf16Exp.size());
|
||||
for (auto i = 0; i < utf16Exp.size(); i++) {
|
||||
ASSERT_EQ(utf16Exp[i], utf16Res[i]);
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, byte_length_test_Default) {
|
||||
|
||||
std::string f("alpha");
|
||||
auto array = NDArrayFactory::string(f);
|
||||
|
||||
ASSERT_EQ(f.length(), StringUtils::byteLength(array));
|
||||
|
||||
std::u16string f16(u"alpha");
|
||||
auto array16 = NDArrayFactory::string(f16);
|
||||
|
||||
ASSERT_EQ(sizeof(char16_t)*f16.length(), StringUtils::byteLength(array16));
|
||||
|
||||
std::u32string f32(U"alpha");
|
||||
auto array32 = NDArrayFactory::string(f32);
|
||||
|
||||
ASSERT_EQ(sizeof(char32_t) * f32.length(), StringUtils::byteLength(array32));
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, byte_length_test_UTF16) {
|
||||
std::string f(u8"alpha");
|
||||
auto array = NDArrayFactory::string(f, nd4j::DataType::UTF16);
|
||||
|
||||
ASSERT_EQ(sizeof(char16_t) * f.length(), StringUtils::byteLength(array));
|
||||
|
||||
std::u16string f16(u"alpha");
|
||||
auto array16 = NDArrayFactory::string(f16, nd4j::DataType::UTF16);
|
||||
|
||||
ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16));
|
||||
|
||||
std::u32string f32(U"alpha");
|
||||
auto array32 = NDArrayFactory::string(f32, nd4j::DataType::UTF16);
|
||||
|
||||
ASSERT_EQ(sizeof(char16_t) * f32.length(), StringUtils::byteLength(array32));
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_UTF16toU8) {
|
||||
|
||||
std::u16string f16(u"alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto array = NDArrayFactory::string(f16, nd4j::DataType::UTF8);
|
||||
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto z = array.e<std::string>(0);
|
||||
|
||||
std::string f(u8"alpha水𝄋ÿ€한𐍈®кею");
|
||||
ASSERT_EQ(f, z);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_UTF32toU8) {
|
||||
std::u32string f32(U"alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto array = NDArrayFactory::string(f32.c_str(), nd4j::DataType::UTF8);
|
||||
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto z = array.e<std::string>(0);
|
||||
std::string f(u8"alpha水𝄋ÿ€한𐍈®кею");
|
||||
ASSERT_EQ(f, z);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_UTF16toU16) {
|
||||
|
||||
std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto array = NDArrayFactory::string(f16, nd4j::DataType::UTF16);
|
||||
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
auto z = array.e<std::u16string>(0);
|
||||
|
||||
ASSERT_EQ(z, f16);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_UTF32toU16) {
|
||||
|
||||
std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto array = NDArrayFactory::string(f32, nd4j::DataType::UTF16);
|
||||
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
auto z = array.e<std::u16string>(0);
|
||||
std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
ASSERT_EQ(z, f16);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_UTF16toU32) {
|
||||
|
||||
std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto array = NDArrayFactory::string(f16, nd4j::DataType::UTF32);
|
||||
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto z = array.e<std::u32string>(0);
|
||||
std::u32string fres(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
ASSERT_EQ(z, fres);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_UTF32toU32) {
|
||||
|
||||
std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto array = NDArrayFactory::string(f32);
|
||||
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
auto z = array.e<std::u32string>(0);
|
||||
ASSERT_EQ(f32, z);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_UTF8toU32) {
|
||||
|
||||
std::string f(u8"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto array = NDArrayFactory::string(f, nd4j::DataType::UTF32);
|
||||
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto z = array.e<std::u32string>(0);
|
||||
ASSERT_EQ(f32, z);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_StringVecU8toUTF16) {
|
||||
auto array = NDArrayFactory::string({ 3, 2 }, { "alpha€", "beta", "gamma水", "phi", "theta", "omega水" }, nd4j::DataType::UTF16);
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_StringVecU8toUTF32) {
|
||||
auto array = NDArrayFactory::string( { 3, 2 }, { "alpha€", "beta水", "gamma", "phi", "theta", "omega" }, nd4j::DataType::UTF32);
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Export_Test_U8toUTF16) {
|
||||
auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, nd4j::DataType::UTF16);
|
||||
|
||||
auto vector = array.asByteVector();
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Export_Test_U8toUTF32) {
|
||||
auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, nd4j::DataType::UTF32);
|
||||
|
||||
auto vector = array.asByteVector();
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_StringVecU16toUTF16) {
|
||||
auto array = NDArrayFactory::string({ 3, 2 }, { u"alpha水", u"beta", u"gamma", u"phi", u"theta水", u"omega" }, nd4j::DataType::UTF16);
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_StringVecU16toUTF32) {
|
||||
auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha水", u"beta", u"gamma水", u"phi", u"theta", u"omega" }, nd4j::DataType::UTF32);
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_StringVecU16toUTF8) {
|
||||
auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha€", u"beta水", u"gamma", u"phi水", u"theta", u"omega" }, nd4j::DataType::UTF8);
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Export_Test_U16toUTF8) {
|
||||
auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, nd4j::DataType::UTF8);
|
||||
|
||||
auto vector = array.asByteVector();
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Export_Test_U16toUTF16) {
|
||||
auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, nd4j::DataType::UTF16);
|
||||
|
||||
auto vector = array.asByteVector();
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Export_Test_U16toUTF32) {
|
||||
auto array = NDArrayFactory::string( { 3 }, { u"alpha水", u"beta", u"gamma水" }, nd4j::DataType::UTF32);
|
||||
|
||||
auto vector = array.asByteVector();
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_StringVecU32toUTF32) {
|
||||
auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha€", U"beta水", U"gamma", U"phi", U"theta", U"omega水" }, nd4j::DataType::UTF32);
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_StringVecU32toUTF16) {
|
||||
auto array = NDArrayFactory::string({ 3, 2 }, { U"alpha水", U"水beta", U"gamma", U"phi水", U"theta", U"omega" }, nd4j::DataType::UTF16);
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
|
||||
printf("Array elements size: \n");
|
||||
for (int e = 0; e < array.lengthOf(); e++) {
|
||||
printf("Element %d size: %d\n", e, static_cast<int>(array.e<std::u16string>(e).size()));
|
||||
}
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_Test_StringVecU32toUTF8) {
|
||||
auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha水", U"beta", U"gamma水", U"phi", U"theta", U"omega" }, nd4j::DataType::UTF8);
|
||||
|
||||
ASSERT_EQ(6, array.lengthOf());
|
||||
ASSERT_EQ(2, array.rankOf());
|
||||
|
||||
array.printIndexedBuffer("String array");
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Export_Test_U32toUTF32) {
|
||||
auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma" }, nd4j::DataType::UTF32);
|
||||
|
||||
auto vector = array.asByteVector();
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Export_Test_U32toUTF16) {
|
||||
auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta水", U"gamma水" }, nd4j::DataType::UTF16);
|
||||
|
||||
auto vector = array.asByteVector();
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Export_Test_U32toUTF8) {
|
||||
auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma水" }, nd4j::DataType::UTF8);
|
||||
|
||||
auto vector = array.asByteVector();
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_dup_UTF16) {
|
||||
std::u16string f(u"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto array = NDArrayFactory::string(f);
|
||||
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto dup = new NDArray(array.dup());
|
||||
|
||||
auto z0 = array.e<std::u16string>(0);
|
||||
auto z1 = dup->e<std::u16string>(0);
|
||||
|
||||
ASSERT_EQ(f, z0);
|
||||
ASSERT_EQ(f, z1);
|
||||
|
||||
delete dup;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_dup_UTF32) {
|
||||
std::u32string f(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
auto array = NDArrayFactory::string(f);
|
||||
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto dup = new NDArray(array.dup());
|
||||
|
||||
auto z0 = array.e<std::u32string>(0);
|
||||
auto z1 = dup->e<std::u32string>(0);
|
||||
|
||||
ASSERT_EQ(f, z0);
|
||||
ASSERT_EQ(f, z1);
|
||||
|
||||
delete dup;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_cast_UTF32toUTF8) {
|
||||
|
||||
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
auto array = NDArrayFactory::string(u32);
|
||||
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto aCast = array.cast(nd4j::DataType::UTF8);
|
||||
|
||||
auto z0 = array.e<std::u32string>(0);
|
||||
auto z1 = aCast.e<std::string>(0);
|
||||
|
||||
ASSERT_EQ(u32, z0);
|
||||
ASSERT_EQ(u8, z1);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_cast_UTF32toUTF16) {
|
||||
|
||||
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
auto array = NDArrayFactory::string(u32);
|
||||
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto aCast = array.cast(nd4j::DataType::UTF16);
|
||||
|
||||
auto z0 = array.e<std::u32string>(0);
|
||||
auto z1 = aCast.e<std::u16string>(0);
|
||||
|
||||
ASSERT_EQ(u32, z0);
|
||||
ASSERT_EQ(u16, z1);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_cast_UTF32toUTF32) {
|
||||
|
||||
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
auto array = NDArrayFactory::string(u32);
|
||||
ASSERT_EQ(nd4j::DataType::UTF32, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto aCast = array.cast(nd4j::DataType::UTF32);
|
||||
|
||||
auto z0 = array.e<std::u32string>(0);
|
||||
auto z1 = aCast.e<std::u32string>(0);
|
||||
|
||||
ASSERT_EQ(u32, z0);
|
||||
ASSERT_EQ(u32, z1);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_cast_UTF16toUTF16) {
|
||||
|
||||
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
auto array = NDArrayFactory::string(u16);
|
||||
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto aCast = array.cast(nd4j::DataType::UTF16);
|
||||
|
||||
auto z0 = array.e<std::u16string>(0);
|
||||
auto z1 = aCast.e<std::u16string>(0);
|
||||
|
||||
ASSERT_EQ(u16, z0);
|
||||
ASSERT_EQ(u16, z1);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_cast_UTF16toUTF32) {
|
||||
|
||||
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
auto array = NDArrayFactory::string(u16);
|
||||
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto aCast = array.cast(nd4j::DataType::UTF32);
|
||||
|
||||
auto z0 = array.e<std::u16string>(0);
|
||||
auto z1 = aCast.e<std::u32string>(0);
|
||||
|
||||
ASSERT_EQ(u32, z1);
|
||||
ASSERT_EQ(u16, z0);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_cast_UTF16toUTF8) {
|
||||
|
||||
std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
auto array = NDArrayFactory::string(u16);
|
||||
ASSERT_EQ(nd4j::DataType::UTF16, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto aCast = array.cast(nd4j::DataType::UTF8);
|
||||
|
||||
auto z0 = array.e<std::u16string>(0);
|
||||
auto z1 = aCast.e<std::string>(0);
|
||||
|
||||
ASSERT_EQ(u8, z1);
|
||||
ASSERT_EQ(u16, z0);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_cast_UTF8toUTF8) {
|
||||
|
||||
std::string u8("€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
auto array = NDArrayFactory::string(u8);
|
||||
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto aCast = array.cast(nd4j::DataType::UTF8);
|
||||
|
||||
auto z0 = array.e<std::string>(0);
|
||||
auto z1 = aCast.e<std::string>(0);
|
||||
|
||||
ASSERT_EQ(u8, z1);
|
||||
ASSERT_EQ(u8, z0);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_cast_UTF8toUTF16) {
|
||||
|
||||
std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
auto array = NDArrayFactory::string(u8);
|
||||
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto aCast = array.cast(nd4j::DataType::UTF16);
|
||||
|
||||
auto z0 = array.e<std::string>(0);
|
||||
auto z1 = aCast.e<std::u16string>(0);
|
||||
|
||||
ASSERT_EQ(u8, z0);
|
||||
ASSERT_EQ(u16, z1);
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(StringTests, Basic_cast_UTF8toUTF32) {
|
||||
|
||||
std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею");
|
||||
|
||||
auto array = NDArrayFactory::string(u8);
|
||||
ASSERT_EQ(nd4j::DataType::UTF8, array.dataType());
|
||||
|
||||
ASSERT_EQ(1, array.lengthOf());
|
||||
ASSERT_EQ(0, array.rankOf());
|
||||
|
||||
auto aCast = array.cast(nd4j::DataType::UTF32);
|
||||
|
||||
auto z0 = array.e<std::string>(0);
|
||||
auto z1 = aCast.e<std::u32string>(0);
|
||||
|
||||
ASSERT_EQ(u8, z0);
|
||||
ASSERT_EQ(u32, z1);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue