Oleh tensor mmul (#231)
* Libnd4j: TensorMMul backprop op #8174, raw implementation
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, merge master and some corrections
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, algorithm update; needs testing, synced with master
* Libnd4j: TensorMMul backprop op #8174, fixed incorrect B axes calculation
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, optimized axes identification and fixed a bug of overlapping indices; added first test, needs testing with different shapes
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, some fixes and improvements; needs more testing
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, fixed order of matrix multiply
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, fixed incorrect axes definition and added tests based on TF; needs additional testing for the case where dLdC is not equal to 1
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, fixed scalar case, added test
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, fixed bp algorithm and axes definition; needs some more testing with different order combinations (f,c; c,f; f,f) and some input checks
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, some checks and corrections, added tests; there is still a problem supporting different input orders (A-f B-c and A-f B-f)
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174, sync master
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Correct bug in MmulHelper::tensorDot(a, b, c, axes_a, axes_b, permutForC)
  Signed-off-by: Yurii <iuriish@yahoo.com>
* Libnd4j: TensorMMul backprop op #8174, code clean-up and refactoring
  Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Add check for linspace-ordered (identity) permutations in ShapeUtils::evalShapeForTensorDot
  Signed-off-by: Yurii <iuriish@yahoo.com>
* Provide additional code in shape::reshape in order to reduce the number of allocation/copy operations during the reshaping procedure
  Signed-off-by: Yurii <iuriish@yahoo.com>
* Further work on the problem of wrong shape evaluation during permute/reshape procedures
  Signed-off-by: Yurii <iuriish@yahoo.com>
* Still looking for the cause of the bug in the reshape/permute code
  Signed-off-by: Yurii <iuriish@yahoo.com>
* Correct bug in transform CUDA native ops
  Signed-off-by: Yurii <iuriish@yahoo.com>
* Correct bug in NDArray::assign
  Signed-off-by: Yurii <iuriish@yahoo.com>
* Remove old shape::reshape code
  Signed-off-by: Yurii <iuriish@yahoo.com>
* Add possibility to disable the copy of the old buffer into the new buffer during a reshape operation in the NDArray class
  Signed-off-by: Yurii <iuriish@yahoo.com>
* Correct bug in tensorDot which had to do with wrong pointer assignments
  Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: Oleh <oleg.semeniv@gmail.com>
parent 8c0e378ec3
commit fe47f52896
				| @ -999,14 +999,14 @@ namespace nd4j { | ||||
|         *  set new order and shape in case of suitable array length (in-place operation) | ||||
|         *  order - order to set | ||||
|         *  shape - shape to set | ||||
|         * | ||||
|         *  copyToNewBuff - if true, the old buffer is copied into the new buffer whenever a new buffer gets allocated during reshaping | ||||
|         *  (a new buffer is allocated for the array if a permute was applied before or the array has irregular strides) | ||||
|         */ | ||||
| 		bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape); | ||||
| 		bool reshapei(const char order, const std::vector<Nd4jLong>& shape); | ||||
| 		bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true); | ||||
| 		bool reshapei(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true); | ||||
| 
 | ||||
|         bool reshapei(const std::initializer_list<Nd4jLong>& shape); | ||||
| 		bool reshapei(const std::vector<Nd4jLong>& shape); | ||||
|         bool reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true); | ||||
| 		bool reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true); | ||||
| 
 | ||||
|         /**
 | ||||
|         *  creates new array with corresponding order and shape, new array will point on _buffer of this array | ||||
| @ -1015,8 +1015,8 @@ namespace nd4j { | ||||
|         * | ||||
|         * if a permute has been applied before or there are irregular strides, then a new buffer is allocated for the new array | ||||
|         */ | ||||
| 		NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const &; | ||||
|         NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) &&; | ||||
| 		NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) const &; | ||||
|         NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) &&; | ||||
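A minimal usage sketch of the new copyToNewBuff flag (the array shapes and the scalar assign at the end are illustrative assumptions, not part of this commit):

    #include <NDArray.h>

    void reshapeSketch(nd4j::NDArray& x) {        // suppose x has shape {2, 3, 4}
        x.reshapei('c', {6, 4});                  // default: if a new buffer must be allocated,
                                                  // the old contents are copied into it
        auto y = x.reshape('c', {4, 6}, false);   // skip the copy: y's contents are undefined,
        y.assign(0.f);                            // which is fine when y is overwritten anyway
    }

Passing copyToNewBuff = false is exactly what the gradient-output reshapes later in this commit do (gradI, gradW, gradB), since those arrays are about to be filled by the backprop kernels.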
| 
 | ||||
|         /**
 | ||||
|         *  calculate strides and set given order | ||||
| @ -1493,7 +1493,7 @@ namespace nd4j { | ||||
|          * @return | ||||
|          */ | ||||
|         bool isS() const; | ||||
|          | ||||
| 
 | ||||
|         template <typename T> | ||||
|         std::vector<T> asVectorT(); | ||||
| 
 | ||||
|  | ||||
| @ -42,7 +42,7 @@ ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| // copy constructor
 | ||||
| NDArray::NDArray(const NDArray& other) { | ||||
|      | ||||
| 
 | ||||
|     _context = other._context; | ||||
|     _offset  = 0; | ||||
| 
 | ||||
| @ -308,7 +308,7 @@ NDArray::NDArray(const std::u16string& u16string, nd4j::DataType dtype, nd4j::La | ||||
|     if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) { | ||||
|         throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); | ||||
|     } | ||||
|      | ||||
| 
 | ||||
| // the buffer holds one word, which is why 1 is used
 | ||||
|     Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); | ||||
| 
 | ||||
| @ -435,11 +435,11 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte | ||||
|     _offset = 0; | ||||
| 
 | ||||
|     setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); | ||||
|      | ||||
| 
 | ||||
|     memcpy(bufferAsT<int8_t>(), &offsets[0], 2 * sizeof(Nd4jLong)); | ||||
|      | ||||
| 
 | ||||
|     auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); | ||||
|      | ||||
| 
 | ||||
|     if (dtype == DataType::UTF8) { | ||||
|         memcpy(data, str.data(), str.size()); | ||||
|     } | ||||
| @ -456,13 +456,13 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte | ||||
| /////////////////////////////////////////////////////////////////////////
 | ||||
| // constructors for vector of  strings
 | ||||
| NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& string, const nd4j::DataType dataType, nd4j::LaunchContext* context) { | ||||
|      | ||||
| 
 | ||||
|     if (!DataTypeUtils::isS(dataType)) | ||||
|         throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); | ||||
| 
 | ||||
|     if (shape::prodLong(shape.data(), shape.size()) != string.size()) | ||||
|         throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); | ||||
|          | ||||
| 
 | ||||
|     for (const auto& str : string) { | ||||
|         if (!unicode::isStringValidU8(str, str + std::char_traits<char>::length(str)) ) { | ||||
|             throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); | ||||
| @ -497,7 +497,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha | ||||
|     setAttached(context->getWorkspace() != nullptr); | ||||
| 
 | ||||
|     memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); | ||||
|      | ||||
| 
 | ||||
|     auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); | ||||
| 
 | ||||
|     auto func = PRAGMA_THREADS_FOR{ | ||||
| @ -631,9 +631,9 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s | ||||
|     setAttached(context->getWorkspace() != nullptr); | ||||
| 
 | ||||
|     memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); | ||||
|      | ||||
| 
 | ||||
|     auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); | ||||
|      | ||||
| 
 | ||||
|     auto func = PRAGMA_THREADS_FOR{ | ||||
|         for (auto e = start; e < stop; e += increment) { | ||||
|              auto cdata = data + offsets[e]; | ||||
| @ -699,7 +699,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha | ||||
| 
 | ||||
|     auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); | ||||
| 
 | ||||
|      | ||||
| 
 | ||||
|     auto func = PRAGMA_THREADS_FOR{ | ||||
|         for (auto e = start; e < stop; e += increment) { | ||||
|              auto cdata = data + offsets[e]; | ||||
| @ -715,7 +715,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha | ||||
|         } | ||||
|     }; | ||||
|     samediff::Threads::parallel_for(func, 0, lengthOf(), 1); | ||||
|      | ||||
| 
 | ||||
|     tickWriteHost(); | ||||
|     syncToDevice(); | ||||
| } | ||||
| @ -764,8 +764,8 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s | ||||
| 
 | ||||
|     memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); | ||||
| 
 | ||||
|     auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);  | ||||
|      | ||||
|     auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); | ||||
| 
 | ||||
|     auto func = PRAGMA_THREADS_FOR{ | ||||
|         for (auto e = start; e < stop; e += increment) { | ||||
|             auto cdata = data + offsets[e]; | ||||
| @ -781,7 +781,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s | ||||
|         } | ||||
|     }; | ||||
|     samediff::Threads::parallel_for(func, 0, lengthOf(), 1); | ||||
|      | ||||
| 
 | ||||
|     tickWriteHost(); | ||||
|     syncToDevice(); | ||||
| } | ||||
| @ -831,7 +831,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha | ||||
|     memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); | ||||
| 
 | ||||
|     auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength); | ||||
|      | ||||
| 
 | ||||
|     auto func = PRAGMA_THREADS_FOR{ | ||||
|         for (auto e = start; e < stop; e += increment) { | ||||
|             auto cdata = data + offsets[e]; | ||||
| @ -847,7 +847,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha | ||||
|         } | ||||
|     }; | ||||
|     samediff::Threads::parallel_for(func, 0, lengthOf(), 1); | ||||
|      | ||||
| 
 | ||||
|     tickWriteHost(); | ||||
|     syncToDevice(); | ||||
| } | ||||
| @ -887,8 +887,8 @@ bool NDArray::isC() const { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| bool NDArray::isS() const { | ||||
|     return (dataType() == DataType::UTF8 ||  | ||||
|             dataType() == DataType::UTF16 ||  | ||||
|     return (dataType() == DataType::UTF8 || | ||||
|             dataType() == DataType::UTF16 || | ||||
|             dataType() == DataType::UTF32); | ||||
| } | ||||
| 
 | ||||
| @ -1197,8 +1197,8 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) { | ||||
|             throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); | ||||
|         } | ||||
| 
 | ||||
|         // memcpy is allowed only for same order && same ews (being equal to 1)
 | ||||
|         if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1) | ||||
|         // memcpy is allowed only when both arrays have c order && the same ews (equal to 1)
 | ||||
|         if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1) | ||||
|             copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT()); | ||||
|         else { | ||||
|             NDArray::prepareSpecialUse({this}, {&other}); | ||||
| @ -1569,20 +1569,25 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector<int>& dimensions) cons | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| void NDArray::printShapeInfo(const char * msg) const { | ||||
|     //shape::printShapeInfo(_shapeInfo);
 | ||||
|     if (msg == nullptr) | ||||
|         shape::printShapeInfoLinear(_shapeInfo); | ||||
|     else { | ||||
|         int rank = shape::rank(_shapeInfo); | ||||
|         int lim = shape::shapeInfoLength(rank); | ||||
|         printf("%s: [", msg); | ||||
|         for (int i = 0; i < shape::shapeInfoLength(rank); i++) { | ||||
|             printf("%lld", (long long) _shapeInfo[i]); | ||||
|             if (i < lim - 1) | ||||
|                 printf(", "); | ||||
|         } | ||||
|         printf("]\n"); | ||||
| 
 | ||||
|     int rank = shape::rank(_shapeInfo); | ||||
|     int lim = shape::shapeInfoLength(rank); | ||||
| 
 | ||||
|     if(msg != nullptr) | ||||
|         printf("shapeInfo %s: [", msg); | ||||
|     else | ||||
|         printf("shapeInfo: ["); | ||||
| 
 | ||||
|     printf("%i,  ", rank); | ||||
|     for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){ | ||||
|         if(i == rank + 1) | ||||
|             printf("  "); | ||||
|         printf("%lld,", _shapeInfo[i]); | ||||
|     } | ||||
|     printf("  %lld,", shape::type(_shapeInfo)); | ||||
|     printf("%lld,", shape::elementWiseStride(_shapeInfo)); | ||||
|     printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo)); | ||||
| 
 | ||||
|     fflush(stdout); | ||||
| } | ||||
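For reference, the rewritten printShapeInfo prints the whole shapeInfo buffer in its standard layout: the rank first, then the shape values, then the strides, and finally the type/ews/order triple. A hypothetical c-ordered 2x3 array would print roughly

    shapeInfo: [2,  2,3,  3,1,  <type>,1,99]

where 99 is just (Nd4jLong)'c'; the exact type code is an ArrayOptions detail and is left symbolic here.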
| 
 | ||||
| @ -1624,7 +1629,7 @@ void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) cons | ||||
|             if (e < limit - 1) | ||||
|                 printf(", "); | ||||
|         } | ||||
|     }  | ||||
|     } | ||||
|     else if (this->isS()) { | ||||
|         // TODO: do we need this printing of offsets?
 | ||||
|         /*
 | ||||
| @ -1773,7 +1778,7 @@ void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const { | ||||
|             printf("%s\n", this->e<bool>(0)?"true":"false"); | ||||
|         } | ||||
|         else if (this->isS()) { | ||||
|             // todo do we need this 
 | ||||
|             // todo do we need this
 | ||||
|             // printf("\"%lld\"\n", this->getOffset(e));
 | ||||
|             printf("\"%s\"\n", this->e<std::string>(0).c_str()); | ||||
|         } | ||||
| @ -1855,19 +1860,19 @@ void NDArray::updateStrides(const char order) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| // set new order and shape in case of suitable array length
 | ||||
| bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape) { | ||||
| bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) { | ||||
|     std::vector<Nd4jLong> vShape(shape); | ||||
|     return reshapei(order, vShape); | ||||
|     return reshapei(order, vShape, copyToNewBuff); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape) { | ||||
|     return reshapei('c', shape); | ||||
| bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) { | ||||
|     return reshapei(ordering(), shape, copyToNewBuff); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| bool NDArray::reshapei(const std::vector<Nd4jLong>& shape) { | ||||
|     return reshapei('c', shape); | ||||
| bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) { | ||||
|     return reshapei(ordering(), shape, copyToNewBuff); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| @ -1918,18 +1923,18 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| // create new array with corresponding order and shape, new array will point to the same _buffer as this array
 | ||||
| NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const & { | ||||
| NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & { | ||||
| 
 | ||||
|     NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset()); | ||||
|     newArr.reshapei(order, shape); | ||||
|     newArr.reshapei(order, shape, copyToNewBuff); | ||||
| 
 | ||||
|     return newArr; | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) && { | ||||
| NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && { | ||||
| 
 | ||||
|     this->reshapei(order, shape); | ||||
|     this->reshapei(order, shape, copyToNewBuff); | ||||
|     return std::move(*this); | ||||
| } | ||||
| 
 | ||||
| @ -2280,7 +2285,7 @@ template <typename T> | ||||
| NDArray NDArray::asT() const{ | ||||
| 
 | ||||
|     auto result = isScalar() ? NDArray('c', {}, std::vector<double>{0.}, DataTypeUtils::fromT<T>(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext()); | ||||
|      | ||||
| 
 | ||||
|     NDArray::prepareSpecialUse({&result}, {this}); | ||||
|     NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr); | ||||
|     NDArray::registerSpecialUse({&result}, {this}); | ||||
| @ -2298,15 +2303,15 @@ NDArray NDArray::asS() const { | ||||
| 
 | ||||
|     auto dtype = DataTypeUtils::fromT<T>(); | ||||
| 
 | ||||
|     if (!(DataTypeUtils::isS(dtype)))  | ||||
|     if (!(DataTypeUtils::isS(dtype))) | ||||
|         throw std::invalid_argument("NDArray::asS: invalid DataType used"); | ||||
|      | ||||
| 
 | ||||
|     if (dtype == dataType()) { | ||||
|          | ||||
| 
 | ||||
|         Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); | ||||
|         const auto nInputoffsets = bufferAsT<Nd4jLong>(); | ||||
|         std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true); | ||||
|          | ||||
| 
 | ||||
|         NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext()); | ||||
|         res.setAttached(getContext()->getWorkspace() != nullptr); | ||||
| 
 | ||||
| @ -2319,7 +2324,7 @@ NDArray NDArray::asS() const { | ||||
|         registerPrimaryUse({ &res }, { this }); | ||||
|         return res; | ||||
|     } | ||||
|   | ||||
| 
 | ||||
|     Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); | ||||
| 
 | ||||
|     std::vector<Nd4jLong> offsets(lengthOf() + 1); | ||||
| @ -2353,7 +2358,7 @@ NDArray NDArray::asS() const { | ||||
| 
 | ||||
|     NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext()); | ||||
|     res.setAttached(getContext()->getWorkspace() != nullptr); | ||||
|      | ||||
| 
 | ||||
|     preparePrimaryUse({ &res }, { this }); | ||||
| 
 | ||||
|     memcpy(res.bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); | ||||
| @ -2403,7 +2408,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| NDArray NDArray::asT(DataType dtype) const { | ||||
|      | ||||
| 
 | ||||
|     if (isS() && !DataTypeUtils::isS(dtype)) | ||||
|         throw std::runtime_error("NDArray::asT: you can't use this method on String array with not string DataType!"); | ||||
| 
 | ||||
| @ -3221,7 +3226,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LI | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| // set new order and shape in case of suitable array length
 | ||||
| bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) { | ||||
| bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) { | ||||
| 
 | ||||
|     // first check whether cshape is identical to the shape of this array; if so, reshape is unnecessary
 | ||||
|     if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data())) | ||||
| @ -3293,19 +3298,15 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) { | ||||
|     Nd4jLong *shapeInfoNew; | ||||
|     ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); | ||||
| 
 | ||||
|     bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew); | ||||
|     bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew); | ||||
| 
 | ||||
|     // we can do this only if no permute was applied and there are no irregular strides
 | ||||
|     if (canReshape) { | ||||
|         if(ordering() == 'c' && order == 'f') | ||||
|             throw std::invalid_argument("NDArray::reshapei(order, shape): in case of reshapeC it doesn't make sense to reshape from c order to f order !"); | ||||
| 
 | ||||
|         shape::setEws(shapeInfoNew, arrLength); | ||||
|         setShapeInfo(shapeInfoNew); | ||||
|     } | ||||
|     else { | ||||
|         NDArray temp(order, shape, dataType(), getContext()); | ||||
|         this->applyTransform(transform::Assign, temp, nullptr); | ||||
|         if(copyToNewBuff) | ||||
|             this->applyTransform(transform::Assign, temp, nullptr); | ||||
|         *this = std::move(temp); | ||||
|     } | ||||
| 
 | ||||
| @ -3463,7 +3464,7 @@ NDArray NDArray::dup(const char newOrder) const { | ||||
|     if (isS()) { | ||||
|         if (dataType() == DataType::UTF8) { | ||||
|             std::vector<std::string> strings(lengthOf()); | ||||
|              | ||||
| 
 | ||||
|             auto func = PRAGMA_THREADS_FOR{ | ||||
|                     for (auto i = start; i < stop; i += increment) { | ||||
|                            strings[i] = std::move(this->e<std::string>(i)); | ||||
| @ -3521,7 +3522,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const { | ||||
| 
 | ||||
|     if (isS()) { | ||||
|         // strings are a special case, we'll compare them one by one, since both arrays are guaranteed to have the same length
 | ||||
|          | ||||
| 
 | ||||
|         if (dataType() == DataType::UTF8) { | ||||
|             for (int e = 0; e < this->lengthOf(); e++) { | ||||
|                 auto s1 = this->e<std::string>(e); | ||||
| @ -3585,7 +3586,7 @@ std::string NDArray::e(const Nd4jLong i) const { | ||||
|     if (i == lengthOf()) | ||||
|         throw std::runtime_error("Can't get std::string for index out of range"); | ||||
| 
 | ||||
|      | ||||
| 
 | ||||
|     if (this->dataType() == DataType::UTF16) { | ||||
|         auto u16 = this->e<std::u16string>(i); | ||||
|         std::string s; | ||||
| @ -4846,7 +4847,7 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni | ||||
|     auto shapeOf = shape::shapeOf(newShapeInfo); | ||||
|     auto stridesOf = shape::stride(newShapeInfo); | ||||
| 
 | ||||
|     Nd4jLong offset(0), subArrLen(1); | ||||
|     Nd4jLong offset = 0; | ||||
|     int n(isStrided ? 3 : 2), first, last, stride; | ||||
| 
 | ||||
|     for (int d = rank - 1; d >= 0; --d) { | ||||
| @ -4863,29 +4864,31 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni | ||||
|             if(shapeOf[d] != 1) | ||||
|                 stridesOf[d] *= stride; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|         subArrLen *= shapeOf[d]; | ||||
|     Nd4jLong *shapeInfoNoUnities = newShapeInfo; | ||||
| 
 | ||||
|     if(!keepUnitiesInShape) { | ||||
| 
 | ||||
|         std::vector<int> dimsWithUnities; | ||||
| 
 | ||||
|         for (uint d = 0; d < rank; ++d) | ||||
|             if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1) | ||||
|                 dimsWithUnities.push_back(d); | ||||
| 
 | ||||
|         if(!dimsWithUnities.empty()) | ||||
|             shapeInfoNoUnities = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace()); | ||||
|     } | ||||
| 
 | ||||
|     // check if there is possibility to set ews = 1
 | ||||
|     shape::setEws(newShapeInfo, subArrLen); | ||||
|     shape::checkStridesSetEwsAndOrder(shapeInfoNoUnities); | ||||
| 
 | ||||
|     NDArray result(_buffer, ShapeDescriptor(newShapeInfo), getContext(), offset + getBufferOffset()); | ||||
|     NDArray result(_buffer, ShapeDescriptor(shapeInfoNoUnities), getContext(), offset + getBufferOffset()); | ||||
|     result._isView = true; | ||||
| 
 | ||||
|     if(!keepUnitiesInShape) { | ||||
|         const int coeff = isStrided ? 3 : 2; | ||||
|         std::vector<Nd4jLong> nonUnitDims; | ||||
| 
 | ||||
|         for (int d = 0; d < rank; ++d) | ||||
|             if(!(idx[coeff*d] != idx[coeff*d+1] && newShapeInfo[d+1] == 1)) | ||||
|                 nonUnitDims.push_back(newShapeInfo[d+1]); | ||||
| 
 | ||||
|         if(nonUnitDims.size() != rank) | ||||
|             result.reshapei(nonUnitDims); | ||||
|     } | ||||
| 
 | ||||
|     RELEASE(newShapeInfo, getContext()->getWorkspace()); | ||||
|     if(newShapeInfo != shapeInfoNoUnities) | ||||
|         RELEASE(shapeInfoNoUnities, getContext()->getWorkspace()); | ||||
| 
 | ||||
|     return result; | ||||
| } | ||||
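A worked illustration of the rewritten operator() (shapes assumed for the example): for a {2,3,4} array, idx = {0,1, 0,0, 0,0} fixes dimension 0 at a single index, so the raw view shape is {1,3,4}. With keepUnitiesInShape = false the unit dimension is stripped via copyShapeInfoWithoutUnites and the returned view has shape {3,4}; with keepUnitiesInShape = true it stays {1,3,4}. The old code reached the same result by calling reshapei on the finished view, whereas the new code builds the unit-free shapeInfo up front and lets shape::checkStridesSetEwsAndOrder decide whether ews = 1 can be set.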
|  | ||||
| @ -30,15 +30,15 @@ | ||||
| 
 | ||||
| namespace nd4j { | ||||
|     class ND4J_EXPORT ShapeBuilders { | ||||
|     public:         | ||||
|     public: | ||||
|         static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr); | ||||
|          | ||||
| 
 | ||||
|         static Nd4jLong* createVectorShapeInfo(const nd4j::DataType dataType, const Nd4jLong length, nd4j::memory::Workspace* workspace = nullptr); | ||||
| 
 | ||||
|         /**
 | ||||
|         *   create shapeInfo for given order basing on shape stored in shapeOnly vector | ||||
|         *   memory allocation for shapeInfo is on given workspace | ||||
|         */         | ||||
|         */ | ||||
|         static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace = nullptr); | ||||
|         static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr); | ||||
|         static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::initializer_list<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr); | ||||
| @ -51,6 +51,13 @@ namespace nd4j { | ||||
|         static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr); | ||||
|         static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr); | ||||
| 
 | ||||
|         /**
 | ||||
|         * allocates memory for a new shapeInfo and copies all information from inShapeInfo into it, except for the dimensions listed in dimsToExclude (unit dimensions) and their corresponding strides | ||||
|         * for example, if inShapeInfo is {5, 2,1,3,1,4,  12,12,4,4,1, 16384,1,99}, dimsToExclude = {1,3} and dimsSize = 2, | ||||
|         * then outShapeInfo will contain {3, 2,3,4,  12,4,1, 16384,1,99} | ||||
|         */ | ||||
|         static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr); | ||||
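A hypothetical call matching the example above (passing nullptr for the workspace to get a heap allocation, as the other ShapeBuilders helpers allow):

    const int dimsToExclude[] = {1, 3};   // positions of the unit dimensions
    Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoWithoutUnites(inShapeInfo, 2, dimsToExclude, nullptr);
    // outShapeInfo now describes the rank-3 array {2,3,4} with strides {12,4,1}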
| 
 | ||||
|         static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr); | ||||
| 
 | ||||
|         static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr); | ||||
|  | ||||
| @ -68,7 +68,7 @@ namespace nd4j { | ||||
|             const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); | ||||
|             const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size(); | ||||
| 
 | ||||
|             auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; | ||||
|             auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];   // shape of sub-arrays (same for all of them)
 | ||||
|             auto oPtr = new Nd4jLong[numOfSubArrs]; | ||||
| 
 | ||||
|             if (numOfSubArrs > 0) | ||||
|  | ||||
| @ -43,23 +43,30 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N | ||||
| 
 | ||||
|     auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); | ||||
| 
 | ||||
|     NDArray aPR = a->permute(permutAt); | ||||
|     NDArray bPR = b->permute(permutBt); | ||||
|     // check whether permutation is necessary
 | ||||
|     const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); | ||||
|     const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); | ||||
| 
 | ||||
|     // check whether reshape is necessary
 | ||||
|     if(!aPR.isSameShape(shapeAt)) | ||||
|         aPR.reshapei( shapeAt); | ||||
|     if(!bPR.isSameShape(shapeBt)) | ||||
|         bPR.reshapei( shapeBt); | ||||
|     const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt)); | ||||
|     const NDArray* bPR = bP->isSameShape(shapeBt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt)); | ||||
| 
 | ||||
|     NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0); | ||||
|     NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0); | ||||
| 
 | ||||
|     c->reshapei(outShape); | ||||
| 
 | ||||
|     if(aP != aPR) | ||||
|         delete aPR; | ||||
|     if(bP != bPR) | ||||
|         delete bPR; | ||||
|     if(a != aP) | ||||
|         delete aP; | ||||
|     if(b != bP) | ||||
|         delete bP; | ||||
| 
 | ||||
|     return c; | ||||
| } | ||||
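A concrete shape walk-through of this path (values assumed, not from the diff): for A with shape [2,3,4], B with shape [4,3,5], axes_0 = {1,2} and axes_1 = {1,0}, evalShapeForTensorDot produces shapeAt = {2,12} and shapeBt = {12,5}, the mmul yields a {2,5} result, and reshapei(outShape) leaves it as C[2,5]. Note that aP/bP/aPR/bPR are only materialized when a permute or reshape is actually required; the trailing pointer comparisons then free exactly the temporaries that were allocated.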
| 
 | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////////
 | ||||
| void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) { | ||||
| 
 | ||||
| @ -67,32 +74,38 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, | ||||
|     std::vector<Nd4jLong> shapeAt, shapeBt; | ||||
|     ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt); | ||||
| 
 | ||||
|     NDArray *cP(c), *cPR(c); | ||||
| 
 | ||||
|     // check whether permutation is required
 | ||||
|     if(!permutForC.empty()) | ||||
|         cP = new NDArray(c->permute(permutForC)); | ||||
|     NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC)); | ||||
| 
 | ||||
|     auto aPR = a->permute(permutAt); | ||||
|     auto bPR = b->permute(permutBt); | ||||
|     // check whether permutation is necessary
 | ||||
|     const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt)); | ||||
|     const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt)); | ||||
| 
 | ||||
|     // check whether reshape is necessary
 | ||||
|     if(!aPR.isSameShape(shapeAt)) | ||||
|             aPR.reshapei(shapeAt); | ||||
|     if(!bPR.isSameShape(shapeBt)) | ||||
|             bPR.reshapei(shapeBt); | ||||
|     const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt)); | ||||
|     const NDArray* bPR = bP->isSameShape(shapeBt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt)); | ||||
| 
 | ||||
|     if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)})) | ||||
|         cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)})); | ||||
|     std::vector<Nd4jLong> requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)}; | ||||
| 
 | ||||
|     mmul(&aPR, &bPR, cPR, 1.0, 0.0); | ||||
|     NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false)); | ||||
| 
 | ||||
|     mmul(aPR, bPR, cPR, 1.0, 0.0); | ||||
| 
 | ||||
|     if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer())   // this means both permute and reshape have been performed on c; cP always points to c->getBuffer()
 | ||||
|         cP->assign(cPR); | ||||
| 
 | ||||
|     if(cPR != c) | ||||
|     if(aP != aPR) | ||||
|         delete aPR; | ||||
|     if(bP != bPR) | ||||
|         delete bPR; | ||||
|     if(a != aP) | ||||
|         delete aP; | ||||
|     if(b != bP) | ||||
|         delete bP; | ||||
| 
 | ||||
|     if(cP != cPR) | ||||
|         delete cPR; | ||||
|     if(cP != c) | ||||
|     if(c != cP) | ||||
|         delete cP; | ||||
| } | ||||
| 
 | ||||
| @ -129,7 +142,7 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c, | ||||
|     if(!whatToDoWithC.empty()) { | ||||
|         cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c); | ||||
|         for(int i = 0; i < cArrs.size()-1; ++i) | ||||
|             cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i]));  // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
 | ||||
|             cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false));  // the first element of cArrs (cArrs[0]) is ignored and always equals c
 | ||||
|     } | ||||
| 
 | ||||
|     mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0); | ||||
| @ -208,7 +221,7 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B, | ||||
|     // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
 | ||||
|     if(isAVector && bRank == 2) { | ||||
|         NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()}));               // A{M} -> A2{1,M}
 | ||||
|         NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N}
 | ||||
|         NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N}
 | ||||
|         auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder);                                // result{1,N}
 | ||||
|         delete A2; | ||||
|         delete C2; | ||||
|  | ||||
| @ -139,5 +139,15 @@ namespace nd4j { | ||||
|         return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace); | ||||
|     } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) { | ||||
| 
 | ||||
|     Nd4jLong *outShapeInfo = nullptr; | ||||
|     ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong); | ||||
| 
 | ||||
|     shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo); | ||||
| 
 | ||||
|     return outShapeInfo; | ||||
| } | ||||
| 
 | ||||
| } | ||||
| @ -75,10 +75,23 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn | ||||
|     permutBt = axesB; | ||||
|     permutBt.insert(permutBt.end(), list_B.begin(), list_B.end()); | ||||
| 
 | ||||
|     // if permut is the identity permutation {0,1,2,...,rank-1}, then no permutation is needed and we return an empty vector in that case
 | ||||
|     uint i1, i2; | ||||
|     for(i1 = 0; i1 < aRank; ++i1) | ||||
|         if(permutAt[i1] != i1) | ||||
|             break; | ||||
|     if(i1 == aRank) | ||||
|         permutAt = {}; | ||||
|     for(i2 = 0; i2 < bRank; ++i2) | ||||
|         if(permutBt[i2] != i2) | ||||
|             break; | ||||
|     if(i2 == bRank) | ||||
|         permutBt = {}; | ||||
| 
 | ||||
|     Nd4jLong n2 = 1; | ||||
|     for (int i = 0; i < axeAsize; i++) | ||||
|         n2 *= aShapeInfo[axesA[i] + 1]; | ||||
|     shapeAt = {-1, n2}; | ||||
|     shapeAt = {shape::length(aShapeInfo) / n2, n2}; | ||||
| 
 | ||||
|     std::vector<Nd4jLong> oldShapeA; | ||||
|     oldShapeA.resize(list_A.size()); | ||||
| @ -89,7 +102,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn | ||||
|     Nd4jLong n3 = 1; | ||||
|     for (int i = 0; i < axeBsize; i++) | ||||
|         n3 *= bShapeInfo[axesB[i] + 1]; | ||||
|     shapeBt = {n3, -1}; | ||||
|     shapeBt = {n3, shape::length(bShapeInfo) / n3}; | ||||
| 
 | ||||
|     std::vector<Nd4jLong> oldShapeB; | ||||
|     oldShapeB.resize(list_B.size()); | ||||
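Worked numbers for the two formulas above (example shapes assumed): if aShapeInfo describes A[2,3,4] and axesA = {2}, then n2 = 4 and shapeAt = {24/4, 4} = {6,4}; if bShapeInfo describes B[4,5] and axesB = {0}, then n3 = 4 and shapeBt = {4, 20/4} = {4,5}. Computing the free dimension explicitly as length/n replaces the old {-1, n2} / {n3, -1} convention, which relied on the reshape routine inferring the missing dimension.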
| @ -306,10 +319,10 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in | ||||
|     Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) { | ||||
| 
 | ||||
|         if (!arr.nonNull()) | ||||
|             throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!"); | ||||
|             throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!"); | ||||
| 
 | ||||
|         if (rank != arr.rankOf()) | ||||
|             throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!"); | ||||
|             throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!"); | ||||
| 
 | ||||
|         auto shapeInfoLength = shape::shapeInfoLength(rank); | ||||
|         // allocate memory for new array - shapeInfo
 | ||||
|  | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -84,7 +84,7 @@ namespace functions { | ||||
| 	    	auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 			int totalThreads = gridDim.x * blockDim.x; | ||||
| 
 | ||||
| 		    if(xEws > 0 && zEws > 0 && xOrder == zOrder) { | ||||
| 		    if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { | ||||
| 
 | ||||
| 				for (int i = tid; i < length; i += totalThreads) | ||||
| 					z[i * zEws] = OpType::op(x[i * xEws], params); | ||||
|  | ||||
| @ -89,7 +89,7 @@ namespace functions { | ||||
| 	    	    auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 				int totalThreads = gridDim.x * blockDim.x; | ||||
| 
 | ||||
| 		        if(xEws > 0 && zEws > 0 && xOrder == zOrder) { | ||||
| 		        if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { | ||||
| 
 | ||||
| 					for (int i = tid; i < length; i += totalThreads) | ||||
| 						z[i * zEws] = OpType::op(x[i * xEws], params); | ||||
|  | ||||
| @ -97,7 +97,7 @@ namespace functions { | ||||
| 	    	    auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 				int totalThreads = gridDim.x * blockDim.x; | ||||
| 
 | ||||
| 		        if(xEws > 0 && zEws > 0 && xOrder == zOrder) { | ||||
| 		        if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { | ||||
| 
 | ||||
| 					for (Nd4jLong i = tid; i < length; i += totalThreads) | ||||
|                         z[i * zEws] = OpType::op(x[i * xEws], params); | ||||
|  | ||||
| @ -87,7 +87,7 @@ namespace functions { | ||||
| 	    	    auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 				int totalThreads = gridDim.x * blockDim.x; | ||||
| 
 | ||||
| 		        if(xEws > 0 && zEws > 0 && xOrder == zOrder) { | ||||
| 		        if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { | ||||
| 
 | ||||
| 					for (int i = tid; i < length; i += totalThreads) | ||||
| 						z[i * zEws] = OpType::op(x[i * xEws], params); | ||||
|  | ||||
| @ -89,7 +89,7 @@ namespace functions { | ||||
| 	    	    auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 				int totalThreads = gridDim.x * blockDim.x; | ||||
| 
 | ||||
| 		        if(xEws > 0 && zEws > 0 && xOrder == zOrder) { | ||||
| 		        if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') { | ||||
| 
 | ||||
| 					for (int i = tid; i < length; i += totalThreads) | ||||
| 						z[i * zEws] = OpType::op(x[i * xEws], params); | ||||
|  | ||||
| @ -21,70 +21,174 @@ | ||||
| #include <op_boilerplate.h> | ||||
| #if NOT_EXCLUDED(OP_tensormmul) | ||||
| 
 | ||||
| #include <numeric> | ||||
| #include <helpers/ShapeUtils.h> | ||||
| #include <ops/declarable/CustomOperations.h> | ||||
| #include <MmulHelper.h> | ||||
| 
 | ||||
| 
 | ||||
| namespace nd4j { | ||||
|     namespace ops { | ||||
|         CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) { | ||||
|             auto a = INPUT_VARIABLE(0); | ||||
|             auto b = INPUT_VARIABLE(1); | ||||
| namespace ops  { | ||||
| 
 | ||||
|             auto c = OUTPUT_VARIABLE(0); //
 | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) { | ||||
| 
 | ||||
|             REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same"); | ||||
|     auto a = INPUT_VARIABLE(0); | ||||
|     auto b = INPUT_VARIABLE(1); | ||||
| 
 | ||||
|             // building axes
 | ||||
|             int axe0_size = INT_ARG(0); | ||||
|             int axe1_size = INT_ARG(axe0_size+1); | ||||
|             std::vector<int> axes_0(axe0_size), axes_1(axe1_size); | ||||
|             for (int e = 0; e < axe0_size; e++) | ||||
|                 axes_0[e] = (int) INT_ARG(e+1); | ||||
|     auto c = OUTPUT_VARIABLE(0); | ||||
| 
 | ||||
|             for (int e = 0; e < axe1_size; e++) | ||||
|                 axes_1[e] = (int) INT_ARG(e + axe0_size + 2); | ||||
|     REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same"); | ||||
| 
 | ||||
|             nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size()); | ||||
|     // building axes
 | ||||
|     int axe0_size = INT_ARG(0); | ||||
|     int axe1_size = INT_ARG(axe0_size+1); | ||||
|     std::vector<int> axes_0(axe0_size), axes_1(axe1_size); | ||||
|     for (int e = 0; e < axe0_size; e++) | ||||
|         axes_0[e] = (int)INT_ARG(e + 1); | ||||
| 
 | ||||
|             MmulHelper::tensorDot(a, b, c, axes_0, axes_1); | ||||
|             return Status::OK(); | ||||
|         } | ||||
|         DECLARE_SYN(tensordot, tensormmul); | ||||
|     for (int e = 0; e < axe1_size; e++) | ||||
|         axes_1[e] = (int)INT_ARG(e + axe0_size + 2); | ||||
| 
 | ||||
|     nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size()); | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(tensormmul) {                | ||||
|          | ||||
|             auto aShapeInfo = inputShape->at(0); | ||||
|             auto bShapeInfo = inputShape->at(1); | ||||
|     MmulHelper::tensorDot(a, b, c, axes_0, axes_1); | ||||
|     return Status::OK(); | ||||
| } | ||||
| DECLARE_SYN(tensordot, tensormmul); | ||||
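The integer arguments are packed as {number of A axes, A axes..., number of B axes, B axes...}. A hypothetical invocation contracting axis 1 of A with axis 0 of B (i.e. a plain matrix multiply; the shapes and the 5-argument execute call mirror the conv1d usage later in this commit, but this snippet is a sketch, not part of the diff):

    nd4j::NDArray A('c', {2, 3}, nd4j::DataType::FLOAT32);
    nd4j::NDArray B('c', {3, 4}, nd4j::DataType::FLOAT32);
    nd4j::NDArray C('c', {2, 4}, nd4j::DataType::FLOAT32);

    nd4j::ops::tensormmul op;
    // iArgs = {1, 1,  1, 0}: one contraction axis for A (axis 1), one for B (axis 0)
    auto status = op.execute({&A, &B}, {&C}, {}, {1, 1, 1, 0}, {});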
| 
 | ||||
|             REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same"); | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| DECLARE_SHAPE_FN(tensormmul) { | ||||
| 
 | ||||
|             // building axes
 | ||||
|             int axe0_size = INT_ARG(0); | ||||
|             int axe1_size = INT_ARG(axe0_size+1); | ||||
|             std::vector<int> axes_0(axe0_size), axes_1(axe1_size); | ||||
|             for (int e = 0; e < axe0_size; e++) | ||||
|                 axes_0[e] = (int) INT_ARG(e+1); | ||||
|     auto aShapeInfo = inputShape->at(0); | ||||
|     auto bShapeInfo = inputShape->at(1); | ||||
| 
 | ||||
|             for (int e = 0; e < axe1_size; e++) | ||||
|                 axes_1[e] = (int) INT_ARG(e + axe0_size + 2); | ||||
|     REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same"); | ||||
| 
 | ||||
|             // evaluate shapes 
 | ||||
|             std::vector<int> permutAt, permutBt; | ||||
|             std::vector<Nd4jLong> shapeAt, shapeBt; | ||||
|             auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); | ||||
|     // building axes
 | ||||
|     int axe0_size = INT_ARG(0); | ||||
|     int axe1_size = INT_ARG(axe0_size+1); | ||||
|     std::vector<int> axes_0(axe0_size), axes_1(axe1_size); | ||||
|     for (int e = 0; e < axe0_size; e++) | ||||
|         axes_0[e] = (int) INT_ARG(e+1); | ||||
| 
 | ||||
|             return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape))); | ||||
|         } | ||||
|     for (int e = 0; e < axe1_size; e++) | ||||
|         axes_1[e] = (int) INT_ARG(e + axe0_size + 2); | ||||
| 
 | ||||
|         DECLARE_TYPES(tensormmul) { | ||||
|             getOpDescriptor() | ||||
|                     ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) | ||||
|                     ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) | ||||
|                     ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); | ||||
|         } | ||||
|     // evaluate shapes
 | ||||
|     std::vector<int> permutAt, permutBt; | ||||
|     std::vector<Nd4jLong> shapeAt, shapeBt; | ||||
|     auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); | ||||
| 
 | ||||
|     return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape))); | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| DECLARE_TYPES(tensormmul) { | ||||
|     getOpDescriptor() | ||||
|             ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) | ||||
|             ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}) | ||||
|             ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF}); | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) { | ||||
| 
 | ||||
|     auto A = INPUT_VARIABLE(0); | ||||
|     auto B = INPUT_VARIABLE(1); | ||||
| 
 | ||||
|     auto dLdC = INPUT_VARIABLE(2); | ||||
| 
 | ||||
|     auto dLdA = OUTPUT_VARIABLE(0); | ||||
|     auto dLdB = OUTPUT_VARIABLE(1); | ||||
| 
 | ||||
|     REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same"); | ||||
| 
 | ||||
|     int axe0Size = INT_ARG(0); | ||||
|     int axe1Size = INT_ARG(axe0Size + 1); | ||||
| 
 | ||||
|     auto Arank = A->rankOf(); | ||||
|     auto Brank = B->rankOf(); | ||||
|     auto dLdCrank = dLdC->rankOf(); | ||||
| 
 | ||||
|     REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: the rank of A must be greater than or equal to the number of A axes"); | ||||
| 
 | ||||
|     REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: the rank of B must be greater than or equal to the number of B axes"); | ||||
| 
 | ||||
|     // building axes
 | ||||
|     std::vector<int> axes0(axe0Size), axes1(axe1Size); | ||||
|     for (uint e = 0; e < axe0Size; e++) | ||||
|         axes0[e] = (int)INT_ARG(e + 1); | ||||
|     for (uint e = 0; e < axe1Size; e++) | ||||
|         axes1[e] = (int)INT_ARG(e + axe0Size + 2); | ||||
| 
 | ||||
|     std::vector<int> permutAt, permutBt; | ||||
|     std::vector<Nd4jLong> shapeAt, shapeBt; | ||||
| 
 | ||||
|     ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt); | ||||
| 
 | ||||
|     // special case for scalar value
 | ||||
|     if (dLdC->isScalar()) { | ||||
| 
 | ||||
|         dLdA->assign((*dLdC) * *B); | ||||
|         dLdB->assign((*dLdC) * *A); | ||||
| 
 | ||||
|         return Status::OK(); | ||||
|     } | ||||
| 
 | ||||
|     std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0); | ||||
|     std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1); | ||||
| 
 | ||||
|     // the dLdC rank is split in half: the first half of its axes corresponds to A, the remaining axes to B
 | ||||
|     std::vector<int> axesAdLdC, axesBdLdC; | ||||
|     if (dLdCrank > 1) { | ||||
|         axesAdLdC.resize(dLdCrank / 2); | ||||
|         std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0); | ||||
|         axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC); | ||||
|     } | ||||
|     else { | ||||
|         axesAdLdC.push_back(0); | ||||
|         axesBdLdC.push_back(0); | ||||
|     } | ||||
| 
 | ||||
|     // calculate dLdA
 | ||||
|     MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt); | ||||
| 
 | ||||
|     // calculate dLdB
 | ||||
|     MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt); | ||||
| 
 | ||||
|     return Status::OK(); | ||||
| } | ||||
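A sanity check of the gradient logic on a plain matrix multiply (shapes assumed): let A be [2,3], B be [3,4], axes0 = {1}, axes1 = {0}, so C and dLdC are [2,4]. Then axesA = {0}, axesB = {1}, and since dLdCrank = 2 the split gives axesAdLdC = {0} and axesBdLdC = {1}. The two tensorDot calls reduce to

    dLdA = tensorDot(dLdC, B, {1}, {1})   // contract the size-4 axes -> [2,3], i.e. dLdC * B^T
    dLdB = tensorDot(A, dLdC, {0}, {0})   // contract the size-2 axes -> [3,4], i.e. A^T * dLdC

which matches the familiar matmul backprop; permutAt and permutBt are empty (identity) here, thanks to the identity-permutation check added in evalShapeForTensorDot.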
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| DECLARE_SHAPE_FN(tensormmul_bp) { | ||||
| 
 | ||||
|     auto aShapeInfo = inputShape->at(0); | ||||
|     auto bShapeInfo = inputShape->at(1); | ||||
|     auto dLShapeInfo = inputShape->at(2); | ||||
| 
 | ||||
|     REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) && | ||||
|                  (ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same"); | ||||
| 
 | ||||
|     Nd4jLong* dLdAShapeInfo = nullptr; | ||||
|     Nd4jLong* dLdBShapeInfo = nullptr; | ||||
| 
 | ||||
|     COPY_SHAPE(aShapeInfo, dLdAShapeInfo); | ||||
|     COPY_SHAPE(bShapeInfo, dLdBShapeInfo); | ||||
| 
 | ||||
|     return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo)); | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| DECLARE_TYPES(tensormmul_bp) { | ||||
|     getOpDescriptor() | ||||
|         ->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS
 | ||||
|         ->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) | ||||
|         ->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) | ||||
|         ->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) | ||||
|         ->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }); | ||||
| } | ||||
| } | ||||
| } | ||||
| 
 | ||||
| #endif | ||||
| @ -79,7 +79,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) { | ||||
|     } | ||||
| 
 | ||||
|     auto inputReshaped   = input  ->reshape(input->ordering(),   reshapeForInput); | ||||
|     auto outputReshaped  = output ->reshape(output->ordering(),  reshapeForOutput); | ||||
|     auto outputReshaped  = output ->reshape(output->ordering(),  reshapeForOutput, false); | ||||
|     auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});   // [kW, iC, oC] -> [1, kW, iC, oC]
 | ||||
| 
 | ||||
|     nd4j::ops::conv2d conv2d; | ||||
| @ -216,10 +216,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { | ||||
|     } | ||||
| 
 | ||||
|     auto inputReshaped   = input  ->reshape(input->ordering(),  reshapeForInput); | ||||
|     auto gradIReshaped   = gradI  ->reshape(gradI->ordering(),  reshapeForInput); | ||||
|     auto gradIReshaped   = gradI  ->reshape(gradI->ordering(),  reshapeForInput, false); | ||||
|     auto gradOReshaped   = gradO  ->reshape(gradO->ordering(),  reshapeForGradO); | ||||
|     auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});    // [kW, iC, oC] -> [1, kW, iC, oC]
 | ||||
|     auto gradWReshaped   = gradW  ->reshape(gradW->ordering(),  {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});    // [kW, iC, oC] -> [1, kW, iC, oC]
 | ||||
|     auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});       // [kW, iC, oC] -> [1, kW, iC, oC]
 | ||||
|     auto gradWReshaped   = gradW  ->reshape(gradW->ordering(),  {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
 | ||||
| 
 | ||||
|     nd4j::ops::conv2d_bp conv2dBP; | ||||
|     auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW,  1,sW,  0,pW,  1,dW,  paddingMode,  !isNCW}, {}); | ||||
|  | ||||
| @ -239,7 +239,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { | ||||
|     //----- calculation of gradO -----//
 | ||||
|     if(gradB) { | ||||
|         if(gradB->rankOf() == 2) | ||||
|             gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); | ||||
|             gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); | ||||
|         gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot);                          // sum over bS oD oH oW
 | ||||
|         if(gradB != OUTPUT_VARIABLE(2)) | ||||
|             delete gradB; | ||||
|  | ||||
| @ -233,7 +233,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) { | ||||
|     // ----- calculation of gradB ----- //
 | ||||
|     if(gradB) { | ||||
|         if(gradB->rankOf() == 2) | ||||
|             gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()})); | ||||
|             gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false)); | ||||
|         gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3});                                // sum over bS, oH, oW
 | ||||
|         if(gradB != OUTPUT_VARIABLE(2)) | ||||
|             delete gradB; | ||||
|  | ||||
| @ -243,7 +243,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { | ||||
|     // ----- calculation of gradB ----- //
 | ||||
|     if(gradB) { | ||||
|         if(gradB->rankOf() == 2) | ||||
|             gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); | ||||
|             gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); | ||||
|         gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4});                                // sum over bS, oD, oH, oW
 | ||||
|         if(gradB != OUTPUT_VARIABLE(2)) | ||||
|             delete gradB; | ||||
|  | ||||
| @ -61,13 +61,13 @@ namespace nd4j { | ||||
|             } | ||||
| 
 | ||||
|             auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); | ||||
|             auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); | ||||
|             auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); | ||||
| 
 | ||||
|             return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target); | ||||
|         } | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(resize_area) { | ||||
|             auto shapeList = SHAPELIST();  | ||||
|             auto shapeList = SHAPELIST(); | ||||
|             auto in = inputShape->at(0); | ||||
| 
 | ||||
|             Nd4jLong* outputShape; | ||||
| @ -90,7 +90,7 @@ namespace nd4j { | ||||
|             } | ||||
| 
 | ||||
|             REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4 or 3, but %i given.", inRank); | ||||
|              | ||||
| 
 | ||||
|             ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); | ||||
|             outputShape[0] = inRank; | ||||
|             if (inRank == 4) { | ||||
|  | ||||
| @ -62,13 +62,13 @@ namespace nd4j { | ||||
|             REQUIRE_TRUE(!halfPixelAlign || !alignCorners, 0, "resize_bicubic: `half_pixel_centers' can be true only when `align_corners' is false"); | ||||
| 
 | ||||
|             auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); | ||||
|             auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); | ||||
|             auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); | ||||
| 
 | ||||
|             return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target); | ||||
|         } | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(resize_bicubic) { | ||||
|             auto shapeList = SHAPELIST();  | ||||
|             auto shapeList = SHAPELIST(); | ||||
|             auto in = inputShape->at(0); | ||||
| 
 | ||||
|             Nd4jLong* outputShape; | ||||
| @ -82,7 +82,7 @@ namespace nd4j { | ||||
|             height = newImageSize->e<int>(1); | ||||
| 
 | ||||
|             REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4 or 3, but %i given.", inRank); | ||||
|              | ||||
| 
 | ||||
|             ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); | ||||
|             outputShape[0] = inRank; | ||||
|             if (inRank == 4) { | ||||
|  | ||||
| @ -43,7 +43,7 @@ namespace nd4j { | ||||
|             REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equal, but %i and %i occurred.", inRank, output->rankOf()); | ||||
| 
 | ||||
|             auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); | ||||
|             auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); | ||||
|             auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); | ||||
| 
 | ||||
|             if (block.width() > 1) { | ||||
|                 auto newImageSize = INPUT_VARIABLE(1); | ||||
| @ -71,7 +71,7 @@ namespace nd4j { | ||||
|         } | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(resize_bilinear) { | ||||
|             auto shapeList = SHAPELIST();  | ||||
|             auto shapeList = SHAPELIST(); | ||||
|             auto in = inputShape->at(0); | ||||
| 
 | ||||
|             Nd4jLong* outputShape; | ||||
| @ -94,7 +94,7 @@ namespace nd4j { | ||||
|                 width = INT_ARG(0); | ||||
|                 height = INT_ARG(1); | ||||
|             } | ||||
|              | ||||
| 
 | ||||
|             ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong); | ||||
|             outputShape[0] = inRank; | ||||
|             if (inRank == 4) { | ||||
|  | ||||
| @ -63,13 +63,13 @@ namespace nd4j { | ||||
|             REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0,  "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height); | ||||
| 
 | ||||
|             auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); | ||||
|             auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); | ||||
|             auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false); | ||||
| 
 | ||||
|             return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target); | ||||
|         } | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(resize_nearest_neighbor) { | ||||
|             auto shapeList = SHAPELIST();  | ||||
|             auto shapeList = SHAPELIST(); | ||||
|             auto in = inputShape->at(0); | ||||
|             auto inRank = shape::rank(in); | ||||
|             Nd4jLong* outputShape; | ||||
|  | ||||
| @ -36,14 +36,14 @@ namespace nd4j { | ||||
|                     int _a = INT_ARG(e); | ||||
|                     if (_a < 0) | ||||
|                         _a += input->rankOf(); | ||||
|                          | ||||
| 
 | ||||
|                     axis.emplace_back(_a); | ||||
|                 } | ||||
|             else if (block.width() > 1) { | ||||
|                 auto a = INPUT_VARIABLE(1); | ||||
|                 for (Nd4jLong e = 0; e < a->lengthOf(); e++) { | ||||
|                     int _a = a->e<int>(e); | ||||
|                      | ||||
| 
 | ||||
|                     if (_a < 0) | ||||
|                         _a += input->rankOf(); | ||||
| 
 | ||||
| @ -71,7 +71,7 @@ namespace nd4j { | ||||
|             } | ||||
| 
 | ||||
|             if (block.isInplace()) { | ||||
|                 output->reshapei(input->ordering(), shape); | ||||
|                 output->reshapei(input->ordering(), shape, false); | ||||
|             } else { | ||||
|                 auto tmp = input->reshape(input->ordering(), shape); | ||||
|                 output->assign(tmp); | ||||
| @ -106,20 +106,20 @@ namespace nd4j { | ||||
|                     int _a = INT_ARG(e); | ||||
|                     if (_a < 0) | ||||
|                         _a += rank; | ||||
|                          | ||||
| 
 | ||||
|                     axis.emplace_back(_a); | ||||
|                 } | ||||
|             else if (block.width() > 1) { | ||||
|                 auto a = INPUT_VARIABLE(1); | ||||
|                 for (int e = 0; e < a->lengthOf(); e++) { | ||||
|                     int _a = a->e<int>(e); | ||||
|                      | ||||
| 
 | ||||
|                     if (_a < 0) | ||||
|                         _a += rank; | ||||
| 
 | ||||
|                     axis.emplace_back(_a); | ||||
|                 } | ||||
|                  | ||||
| 
 | ||||
|             } | ||||
| 
 | ||||
|             auto order = shape::order(in); | ||||
|  | ||||
| @ -57,7 +57,8 @@ namespace nd4j { | ||||
|          * IArgs[1]... axes values for second array | ||||
|          */ | ||||
|         #if NOT_EXCLUDED(OP_tensormmul) | ||||
|         DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1);    | ||||
|         DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1); | ||||
|         DECLARE_CUSTOM_OP(tensormmul_bp, 3, 2, false, 0, -1); | ||||
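|         // Illustrative usage (a sketch mirroring the tests added in this commit): IArgs pack the | ||||
|         // contraction axes as {numAxesA, axesA..., numAxesB, axesB...}, e.g. | ||||
|         //     nd4j::ops::tensormmul op; | ||||
|         //     auto res = op.evaluate({&A, &B}, {}, {1,1, 1,2});             // contract axis 1 of A with axis 2 of B | ||||
|         //     nd4j::ops::tensormmul_bp opBp; | ||||
|         //     auto resBp = opBp.evaluate({&A, &B, &dLdC}, {}, {1,1, 1,2});  // outputs dLdA and dLdB | ||||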
|         #endif | ||||
| 
 | ||||
|         /**
 | ||||
|  | ||||
| @ -432,7 +432,7 @@ namespace nd4j { | ||||
|                 ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); | ||||
| 
 | ||||
|             NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); | ||||
|             NDArray outputReshaped = output->reshape(output->ordering(), outReShape); | ||||
|             NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); | ||||
| 
 | ||||
|             helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
 | ||||
|             MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput);              // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
 | ||||
| @ -505,7 +505,7 @@ namespace nd4j { | ||||
|             if(gradB) { | ||||
|                 NDArray* gradBR = gradB; | ||||
|                 if(gradB->rankOf() == 2) | ||||
|                     gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); | ||||
|                     gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false)); | ||||
|                 gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1});                      // sum over bS, oH, oW
 | ||||
| 
 | ||||
|                 if(gradBR != gradB) | ||||
|  | ||||
| @ -30,7 +30,7 @@ namespace helpers { | ||||
| void crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) { | ||||
|     auto _a = a->reshape(a->ordering(), {-1, 3}); | ||||
|     auto _b = b->reshape(b->ordering(), {-1, 3}); | ||||
|     auto _o = o->reshape(o->ordering(), {-1, 3}); | ||||
|     auto _o = o->reshape(o->ordering(), {-1, 3}, false); | ||||
| 
 | ||||
|     auto tadsA = _a.allTensorsAlongDimension({1}); | ||||
|     auto tadsB = _b.allTensorsAlongDimension({1}); | ||||
|  | ||||
| @ -244,14 +244,14 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o | ||||
| 
 | ||||
|     // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]
 | ||||
| 
 | ||||
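|     // e.g. (illustrative numbers): bS = 2, oH = oW = 2, iC = 3, blockSize = 2 gives output [8,2,2,3], | ||||
|     // viewed below as [2,2,2,2,2,3] and permuted to [bS, oH, blockSize, oW, blockSize, iC] | ||||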
|     NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)}); | ||||
|     NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)}, false); | ||||
|     outputRearranged0.permutei({2, 3,0, 4,1, 5}); | ||||
| 
 | ||||
|     if(input.lengthOf() == output.lengthOf()) { | ||||
|         outputRearranged0.assign(input); | ||||
|     } | ||||
|     else { | ||||
|         NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)}); | ||||
|         NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)}, false); | ||||
|         BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatch_, (input, outputRearranged1, padBottom, padTop, padLeft, padRight), LIBND4J_TYPES); | ||||
| 
 | ||||
|         if(output.getBuffer() != outputRearranged1.getBuffer()) | ||||
| @ -352,7 +352,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND | ||||
|     for(int j = 1; j < rank; ++i, ++j) | ||||
|         temp[i] = output.sizeAt(j); | ||||
| 
 | ||||
|     NDArray outputRearranged0 = output.reshape(output.ordering(), temp); | ||||
|     NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); | ||||
| 
 | ||||
|     //*** construct permuting std::vector for permutation of output array ***//
 | ||||
| 
 | ||||
| @ -382,7 +382,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND | ||||
|         for(i = 1; i < rank; ++i) | ||||
|             temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i); | ||||
| 
 | ||||
|         NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp); | ||||
|         NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false); | ||||
| 
 | ||||
|         BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchND_, (input, padding, outputRearranged1, numOfSpatialDims), LIBND4J_TYPES); | ||||
| 
 | ||||
|  | ||||
| @ -59,7 +59,7 @@ void FORCEINLINE cross(nd4j::LaunchContext * context, NDArray *a, NDArray *b, ND | ||||
|     void FORCEINLINE _crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) { | ||||
|         auto a_ = a->reshape(a->ordering(), {-1, 3}); | ||||
|         auto b_ = b->reshape(b->ordering(), {-1, 3}); | ||||
|         auto o_ = o->reshape(o->ordering(), {-1, 3}); | ||||
|         auto o_ = o->reshape(o->ordering(), {-1, 3}, false); | ||||
| 
 | ||||
|         auto tadsA = a_.allTensorsAlongDimension({1}); | ||||
|         auto tadsB = b_.allTensorsAlongDimension({1}); | ||||
|  | ||||
| @ -322,7 +322,7 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, | ||||
|         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); | ||||
| 
 | ||||
|     NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); | ||||
|     NDArray outputReshaped = output->reshape(output->ordering(), outReShape); | ||||
|     NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false); | ||||
| 
 | ||||
|     helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] | ||||
|     MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput);              // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] | ||||
| @ -1228,7 +1228,7 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N | ||||
|         NDArray* gradBR = gradB; | ||||
|         if(gradB->rankOf() == 2) | ||||
|             gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); | ||||
|         gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot);                          // sum over bS, oH, oW | ||||
|         gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, false);                          // sum over bS, oH, oW | ||||
|         if(gradBR != gradB) | ||||
|             delete gradBR; | ||||
|     } | ||||
| @ -1310,7 +1310,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con | ||||
|         NDArray* gradBR = gradB; | ||||
|         if(gradB->rankOf() == 2) | ||||
|             gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()})); | ||||
|         gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1});                      // sum over bS, oH, oW | ||||
|         gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}, false);                      // sum over bS, oH, oW | ||||
|         if(gradBR != gradB) | ||||
|             delete gradBR; | ||||
|     } | ||||
|  | ||||
| @ -313,7 +313,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o | ||||
| 
 | ||||
|     // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC] | ||||
| 
 | ||||
|     NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)}); | ||||
|     NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)}, false); | ||||
|     outputRearranged0.permutei({2, 3,0, 4,1, 5}); | ||||
| 
 | ||||
|     if(input.lengthOf() == output.lengthOf()) { | ||||
| @ -322,7 +322,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o | ||||
|     } | ||||
|     else { | ||||
| 
 | ||||
|         NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)}); | ||||
|         NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)}, false); | ||||
| 
 | ||||
|         const int threadsPerBlock = MAX_NUM_THREADS / 2; | ||||
|         const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; | ||||
| @ -439,7 +439,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND | ||||
|     for(int j = 1; j < rank; ++i, ++j) | ||||
|         temp[i] = output.sizeAt(j); | ||||
| 
 | ||||
|     NDArray outputRearranged0 = output.reshape(output.ordering(), temp); | ||||
|     NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false); | ||||
| 
 | ||||
|     //*** construct permuting std::vector for permutation of output array ***// | ||||
| 
 | ||||
| @ -469,7 +469,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND | ||||
|         for(i = 1; i < rank; ++i) | ||||
|             temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i); | ||||
| 
 | ||||
|         NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp); | ||||
|         NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false); | ||||
| 
 | ||||
|         const int threadsPerBlock = MAX_NUM_THREADS / 4; | ||||
|         const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; | ||||
|  | ||||
| @ -471,9 +471,9 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { | ||||
|     if(cI) | ||||
|         cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut})); | ||||
|     if(hL) | ||||
|         hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut})); | ||||
|         hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false)); | ||||
|     if(cL) | ||||
|         cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut})); | ||||
|         cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false)); | ||||
| 
 | ||||
|     lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR); | ||||
| 
 | ||||
|  | ||||
| @ -321,6 +321,280 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) { | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot5) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518}); | ||||
| 
 | ||||
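|     // axes IArgs {1,1, 1,2}: contract axis 1 of x (size 3) with axis 2 of y (size 3); | ||||
|     // the free dims {2,4} of x and {2,4} of y concatenate into the expected {2,4,2,4} | ||||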
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {1,1,1,2}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot6) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {1,1,1,2}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot7) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {1,1,1,2}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot8) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {1,1,1,2}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot9) { | ||||
| 
 | ||||
|     // NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
 | ||||
|     // z.linspace(1);
 | ||||
|     // z.printShapeInfo();
 | ||||
|     // z.printIndexedBuffer();
 | ||||
|     // z.reshapei('c', {4,3});
 | ||||
|     // z.printShapeInfo();
 | ||||
|     // z.printIndexedBuffer();
 | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {1,0,1,0}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot10) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot11) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot12) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot13) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot14) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot15) { | ||||
| 
 | ||||
|     auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); | ||||
|     auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); | ||||
|     auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624}); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(expected.isSameShape(result)); | ||||
|     ASSERT_TRUE(expected.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot16) { | ||||
| 
 | ||||
|     NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32); | ||||
|     NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
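|     // edge case: the lone length-1 axis of x is contracted with the size-1 axis 1 of y, | ||||
|     // leaving y's free dims {2,2} scaled by x's value 2 | ||||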
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
| 
 | ||||
|     ASSERT_TRUE(exp.isSameShape(result)); | ||||
|     ASSERT_TRUE(exp.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, TestTensorDot17) { | ||||
| 
 | ||||
|     NDArray x('f', {16,16}, nd4j::DataType::FLOAT32); | ||||
|     NDArray y('f', {1000,16}, nd4j::DataType::FLOAT32); | ||||
|     NDArray z('c', {16,1000}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto status = op.execute({&x, &y}, {&z}, {}, {1,1, 1,1}, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, status); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, DivergentCheck1) { | ||||
|     auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation("switch"); | ||||
|  | ||||
| @ -708,30 +708,6 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) { | ||||
|     ASSERT_TRUE(nd4j::ops::helpers::multiUnique(arrayList)); | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests12, tensormmul_6) { | ||||
| 
 | ||||
|     NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32); | ||||
|     NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32); | ||||
|     NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul op; | ||||
|     auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||
| 
 | ||||
|     auto *result = results->at(0); | ||||
|     // exp.printShapeInfo();
 | ||||
|     // result->printShapeInfo();
 | ||||
|     // result->printIndexedBuffer();
 | ||||
| 
 | ||||
|     ASSERT_TRUE(exp.isSameShape(result)); | ||||
|     ASSERT_TRUE(exp.equalsTo(result)); | ||||
| 
 | ||||
|     delete results; | ||||
| 
 | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests12, reduceMeanBp_4) { | ||||
| 
 | ||||
|  | ||||
| @ -1560,3 +1560,447 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) { | ||||
| 
 | ||||
|     delete resultsB; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { | ||||
| 
 | ||||
|     NDArray A('c', { 1, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 1, 2, 4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
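|     // dLdC has the forward output's shape: contracting axes {0,1} of A and B leaves free dims {3} and {4}, so C is {3,4} | ||||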
|     NDArray dLdA('c', { 1, 2, 3 }, { 3.3,  8.5,  13.36, 3.7, 9.54, 15. }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdB('c', { 1, 2, 4 }, { 3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) { | ||||
| 
 | ||||
|     NDArray A('c', { 1, 2, 3 }, { 2,2,2, 2,2,2 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 1, 2, 3 }, { 3,3,3, 3,3,3 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 1 }, { 1 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
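|     // contracting axes {1,2} of both A and B leaves only size-1 dims; with dLdC = 1 the | ||||
|     // expected gradients reduce to dLdA = B and dLdB = A, as asserted below | ||||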
|     nd4j::ops::tensormmul_bp op_bp; | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(B.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(B.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(A.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(A.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP3) { | ||||
| 
 | ||||
|     NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 4, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray dA('c', { 3, 2, 2 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dB('c', { 4, 2, 2 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8,  7.04 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) { | ||||
| 
 | ||||
|     NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5,  9, 23, 0.12,  8, 9, 0.1,  0, 124, 3 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 2, 4, 1 }, { 4, 13, .5,  19, 2.3, 1.2,  18, .9 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 3, 2 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84 , 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) { | ||||
| 
 | ||||
|     NDArray A('c', { 3, 4, 1, 1 }, { 0.4, 3, 5,  9, 23, 0.12,  8, 9, 0.1,  0, 124, 3 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 2, 4, 1, 1 }, { 4, 13, .5,  19, 2.3, 1.2,  18, .9 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray dLdA('c', { 3, 4, 1, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdB('c', { 2, 4, 1, 1 }, { 30.49, 3.456, 201.9,  26.1, 32.84,  3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) { | ||||
| 
 | ||||
|     NDArray A('c', { 2, 2, 2 }, { 2,2, 2,2, 2,2, 2,2 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 2, 2, 2 }, { 3,3, 3,3, 3,3, 3,3  }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     auto dLdC = NDArrayFactory::create<float>(1.f); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(B.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(B.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(A.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(A.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP7) { | ||||
| 
 | ||||
|     NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5,  9, 23, 0.12,  8, 9, 0.1,  0, 124, 3 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 2, 4, 1 }, { 4, 13, .5,  19, 2.3, 1.2,  18, .9 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9,  26.1, 32.84,  3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) { | ||||
| 
 | ||||
|     NDArray A('c', { 1, 1, 4, 3 }, { 0.4, 3, 5,  9, 23, 0.12,  8, 9, 0.1,  0, 124, 3 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 1, 1, 4, 2 }, { 4, 13, .5,  19, 2.3, 1.2,  18, .9 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 3, 2 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray dLdA('c', { 1, 1, 4, 3 }, { 20., 23.4,  26.8, 23.35, 27.25, 31.15, 3.97,  4.67,  5.37, 20.88, 24.66, 28.44 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdB('c', { 1, 1, 4, 2 }, { 11.84,   12.68,  39.98,  43.192, 20.65, 22.36, 165.7,   178.4 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP9) { | ||||
| 
 | ||||
|     NDArray A('c', { 3, 2, 2, 1 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 4, 2, 2 ,1 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 3, 1, 4, 1 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray dA('c', { 3, 2, 2, 1 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dB('c', { 4, 2, 2, 1 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8,  7.04 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP10) { | ||||
| 
 | ||||
|     NDArray A('c', { 1, 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 1, 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 1, 3, 1, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
| 
 | ||||
|     NDArray dA('c', { 1, 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5,  11.62, 18.74 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dB('c', { 1, 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP11) { | ||||
| 
 | ||||
|     NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
| 
 | ||||
|     NDArray dA('c', { 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5,  11.62, 18.74 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dB('c', { 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) { | ||||
| 
 | ||||
|     NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray B('c', { 2, 2 ,3 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dLdC('c', { 2, 3, 2, 3 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, | ||||
|                       1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, | ||||
|                       2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray dA('c', { 2, 2, 3 }, { 7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, 49.33, 63.01, 76.69 }, nd4j::DataType::FLOAT32); | ||||
|     NDArray dB('c', { 2, 2, 3 }, { 25.86, 27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, 33.78, 35.46, 37.14 }, nd4j::DataType::FLOAT32); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) { | ||||
| 
 | ||||
|     NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::DOUBLE); | ||||
|     NDArray B('c', { 3, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::DOUBLE); | ||||
|     NDArray dLdC('c', { 3, 2, 3, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, | ||||
|                       1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, | ||||
|                       2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE); | ||||
| 
 | ||||
|     NDArray dA('c', { 3, 2, 2 }, { 7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, 71.69, 62.21, 75.71 }, nd4j::DataType::DOUBLE); | ||||
|     NDArray dB('c', { 3, 2, 2 }, { 26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, 34.14, 35.01, 36.66 }, nd4j::DataType::DOUBLE); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests15, TestTensorMmul_BP14) { | ||||
| 
 | ||||
|     NDArray A('c', { 2, 2, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE); | ||||
| 
 | ||||
|     NDArray B('c', { 2, 2, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::DOUBLE); | ||||
| 
 | ||||
|     NDArray dLdC('c', { 2, 2, 2, 2, 2, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, | ||||
|                       1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, | ||||
|                       1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, | ||||
|                       1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, | ||||
|                       1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::DOUBLE); | ||||
| 
 | ||||
|     NDArray dA('c', { 2, 2, 2, 2 }, { 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24 }, nd4j::DataType::DOUBLE); | ||||
|     NDArray dB('c', { 2, 2, 2, 2 }, { 10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, 23.48, 25.6, 22.12, 24.56, 27., 29.44 }, nd4j::DataType::DOUBLE); | ||||
| 
 | ||||
|     nd4j::ops::tensormmul_bp op_bp; | ||||
| 
 | ||||
|     auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); | ||||
| 
 | ||||
|     ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status()); | ||||
| 
 | ||||
|     auto* dLdAbp = resultsBP->at(0); | ||||
|     auto* dLdBbp = resultsBP->at(1); | ||||
| 
 | ||||
|     ASSERT_TRUE(dA.isSameShape(*dLdAbp)); | ||||
|     ASSERT_TRUE(dA.equalsTo(*dLdAbp)); | ||||
| 
 | ||||
|     ASSERT_TRUE(dB.isSameShape(*dLdBbp)); | ||||
|     ASSERT_TRUE(dB.equalsTo(*dLdBbp)); | ||||
| 
 | ||||
|     delete resultsBP; | ||||
| } | ||||
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP15) {

    NDArray A('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);
    NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);

    NDArray dLdC('f', { 2, 2 }, { 23.0, 24.44, 2.0, 26. }, nd4j::DataType::FLOAT32);

    NDArray dA('c', { 2, 2, 3 }, { 27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, 177.32, 379.08001, 580.839966 }, nd4j::DataType::FLOAT32);
    NDArray dB('f', { 2, 2, 3 }, { 194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., 288.96002, 240., 431.27999, 324. }, nd4j::DataType::FLOAT32);

    nd4j::ops::tensormmul_bp op;
    auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto* dLdA = results->at(0);
    auto* dLdB = results->at(1);

    ASSERT_TRUE(dA.isSameShape(*dLdA));
    ASSERT_TRUE(dA.equalsTo(*dLdA));

    ASSERT_TRUE(dB.isSameShape(*dLdB));
    ASSERT_TRUE(dB.equalsTo(*dLdB));

    delete results;
}
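// BP15 mixes buffer orders ('c' for A, 'f' for B and dLdC, expected gradients in
// matching orders), so it appears to target the permute/reshape code paths rather
// than the plain all-'c' case covered by BP13/BP14.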
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) {

    NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
    NDArray B('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);

    NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);

    const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
    const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });

    nd4j::ops::tensormmul op;
    nd4j::ops::tensormmul_bp op_bp;

    const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, { 1,0 });
    ASSERT_TRUE(isGradCorrect);
}
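// GradCheck::checkGrad presumably runs the forward op, then the BP op, and
// compares the analytic gradients against numerical estimates -- which is why
// BP16/BP17 need no hand-computed dA/dB arrays and dLdC is left uninitialized.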
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {

    NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
    NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);

    NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);

    const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
    const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });

    nd4j::ops::tensormmul op;
    nd4j::ops::tensormmul_bp op_bp;

    const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, { 1,0 });
    ASSERT_TRUE(isGradCorrect);
}
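// For reference, a minimal forward call matching the BP16/BP17 setup (a sketch,
// assuming the IArgs layout inferred above; evaluate() usage mirrors the
// TensorDot tests below) would be:
//
//     nd4j::ops::tensormmul op;
//     auto res = op.evaluate({ &A, &B }, {}, { 2,1,2, 2,1,2 });  // contract axes {1,2} of A with {1,2} of B
//     auto* C  = res->at(0);                                     // 2x2 output, matching dLdC's shape
//     delete res;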

@ -578,246 +578,6 @@ TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) {
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot5) {

    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}
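// Shape check: {1,1,1,2} contracts axis 1 of x (size 3) with axis 2 of y
// (size 3), leaving x's free axes {2,4} followed by y's free axes {2,4} --
// hence the expected {2,4,2,4} output.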

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot6) {

    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot7) {

    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot8) {

    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot9) {

    // NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
    // z.linspace(1);
    // z.printShapeInfo();
    // z.printIndexedBuffer();
    // z.reshapei('c', {4,3});
    // z.printShapeInfo();
    // z.printIndexedBuffer();

    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}
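// Here {1,0,1,0} contracts axis 0 of x with axis 0 of y (both of size 2),
// keeping x's free axes {3,4} and y's free axes {4,3}: an outer-product-like
// {3,4,4,3} result.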

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot10) {

    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}
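// Tests 10-12 share the rank-2 contraction {2,0,1, 2,0,2}: x axes {0,1}
// (sizes 2,3) pair with y axes {0,2} (sizes 2,3), leaving a {4,4} result.
// Only the memory orders of the inputs differ between the three tests.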

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot11) {

    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot12) {

    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot13) {

    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}
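// Tests 13-15 contract x axes {0,2} (sizes 2,4) with y axes {1,0} (sizes 2,4
// of the {4,2,3} input), leaving x's axis 1 and y's axis 2: a {3,3} result.
// Again, the three tests differ only in the input buffer orders.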

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot14) {

    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot15) {

    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
    auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
    auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *result = results->at(0);

    ASSERT_TRUE(expected.isSameShape(result));
    ASSERT_TRUE(expected.equalsTo(result));

    delete results;
}

////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) {

@ -2043,34 +2043,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) {
    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1}, GradCheck::MEAN);

    ASSERT_TRUE(isGradCorrect);

    //************************************//
/*    exclusive = 1; reverse = 0;

    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
    ASSERT_EQ(Status::OK(), result->status());
    z = result->at(0);
    ASSERT_TRUE(expTF.equalsTo(z));
    delete result;
*/
    //************************************//
/*    exclusive = 0; reverse = 1;

    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
    ASSERT_EQ(Status::OK(), result->status());
    z = result->at(0);
    ASSERT_TRUE(expFT.equalsTo(z));
    delete result;
*/
    //************************************//
/*    exclusive = 1; reverse = 1;

    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
    ASSERT_EQ(Status::OK(), result->status());
    z = result->at(0);
    ASSERT_TRUE(expTT.equalsTo(z));
    delete result;
*/
}
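// The commented-out blocks above appear to cover the remaining
// exclusive/reverse combinations of cumprod (expTF, expFT, expTT); only the
// default exclusive = 0, reverse = 0 path is gradient-checked in this hunk.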

////////////////////////////////////////////////////////////////////////////////
@ -2079,11 +2051,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test2) {
    auto inputC = NDArrayFactory::create<double>('c', {2, 2});
    auto axis = NDArrayFactory::create<double>(1.);

//    auto expFF = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240., 11., 132., 1716., 24024., 360360.});
//    auto expTF = NDArrayFactory::create<double>('c', {3, 5}, {1, 1, 2, 6, 24, 1, 6, 42, 336, 3024, 1, 11, 132, 1716, 24024});

//    auto expFT = NDArrayFactory::create<double>('c', {3, 5}, {120, 120, 60, 20, 5, 30240, 5040, 720, 90, 10, 360360, 32760, 2730, 210, 15});    //+++
//    auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {120, 60, 20, 5, 1, 5040, 720, 90, 10, 1, 32760, 2730, 210, 15, 1});
    auto gradO = NDArrayFactory::create<double>('c', {2, 2});

    int exclusive, reverse;

@ -61,7 +61,7 @@ public class TestPCA extends BaseNd4jTest {
            assertEquals("Reconstructed matrix is very different from the original.", 0.0, Diff.getDouble(i), 1.0);
        }
    }

    @Test
    public void testFactorSVDTransposed() {
        int m = 4;