| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author raver119@gmail.com
//
 | 
					
						
							|  |  |  | 
 | 
					
						
#include "../ConstantTadHelper.h"

#include <memory>
#include <mutex>
#include <type_traits>

#include <helpers/TAD.h>
#include <helpers/ShapeUtils.h>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | #ifndef __CUDABLAS__
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-02-24 06:51:01 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | namespace sd { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     ConstantTadHelper::ConstantTadHelper() { | 
					
						
							| 
									
										
										
										
											2020-02-24 06:51:01 +02:00
										 |  |  |         MAP_IMPL<TadDescriptor, TadPack> pack; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         _cache.emplace_back(pack); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ConstantTadHelper* ConstantTadHelper::getInstance() { | 
					
						
							|  |  |  |         if (!_INSTANCE) | 
					
						
							|  |  |  |             _INSTANCE = new ConstantTadHelper(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return _INSTANCE; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-03 22:02:02 +03:00
										 |  |  |     TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-03 22:02:02 +03:00
										 |  |  |     TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-03 22:02:02 +03:00
										 |  |  |     TadPack ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); | 
					
						
							|  |  |  |         return tadForDimensions(tadDescriptor); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-03 22:02:02 +03:00
										 |  |  |     TadPack ConstantTadHelper::tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         TadDescriptor tadDescriptor(descriptor, dimensions, keepUnitiesInShape); | 
					
						
							|  |  |  |         return tadForDimensions(tadDescriptor); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-09-03 22:02:02 +03:00
										 |  |  |     TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |         const int deviceId = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         _mutex.lock(); | 
					
						
							|  |  |  |         if (_cache[deviceId].count(descriptor) == 0) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             const auto shapeInfo = descriptor.originalShape().toShapeInfo(); | 
					
						
							|  |  |  |             const int rank = shape::rank(shapeInfo); | 
					
						
							|  |  |  |             const std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, descriptor.axis()); | 
					
						
							|  |  |  |             const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude); | 
					
						
							|  |  |  |             const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
											  
											
												Oleh tenzor mmul (#231)
* Libnd4j: TensorMMul backprop op #8174, raw implementation
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 merge master and some corrections
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 algorithm update, need testing, sync with  master
* Libnd4j: TensorMMul backprop op #8174 fixed incorrect B axes calculation
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 optimize axes identification and fix bug of indeces overlapping, added first test. need testing with different shapes
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 some fixes and improvements need more testing
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 fixed order of matrix multiply
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 fixed issue of incorrect axes definition, add tests based on TF, need additional testing for case dLdC not equal 1
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 fixed scalar case add test
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 fixed bp algorithm, axes definition, need some mode testing with different orders combination f,c; c,f f,f and add some checks for inputs
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 some checks and corrections added tests, exists the problem with different input orders support A-f B-c and A-f B-f
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 sync master
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* - correct bug in MmulHelper::tensorDot(a, b, c, axes_a, axes_b,permutForC)
Signed-off-by: Yurii <iuriish@yahoo.com>
* Libnd4j: TensorMMul backprop op #8174 code clean up and refactoring
Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* - add check for linspase ordered permutations in ShapeUtils::evalShapeForTensorDot
Signed-off-by: Yurii <iuriish@yahoo.com>
* - provide additional code in shape::reshape stuff in order to reduce amount of allocation/copy operations during reshaping procedure
Signed-off-by: Yurii <iuriish@yahoo.com>
* - further work on problem of wrong shape evaluation during permute/reshape procedures
Signed-off-by: Yurii <iuriish@yahoo.com>
* - still looking for bug reason in reshape/permute stuff
Signed-off-by: Yurii <iuriish@yahoo.com>
* - correct bug in transform cuda native ops
Signed-off-by: Yurii <iuriish@yahoo.com>
* - correct bug in NDArray::assign
Signed-off-by: Yurii <iuriish@yahoo.com>
* - remove old shape::reshape stuff
Signed-off-by: Yurii <iuriish@yahoo.com>
* - add possibility to disable copy of old buffer to new buffer during reshape operation in NDArray class
Signed-off-by: Yurii <iuriish@yahoo.com>
* - correct bug in tensorDot which had to do with wrong pointers assigments
Signed-off-by: Yurii <iuriish@yahoo.com>
Co-authored-by: Oleh <oleg.semeniv@gmail.com>
											
										 
											2020-02-13 19:33:54 +02:00
										 |  |  |             auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];   // shape of sub-arrays (same for all for them)
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             auto oPtr = new Nd4jLong[numOfSubArrs]; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-15 21:34:34 +10:00
										 |  |  |             if (numOfSubArrs > 0) | 
					
						
							| 
									
										
										
										
											2020-03-03 06:32:37 +02:00
										 |  |  |                 shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape()); | 
					
						
							| 
									
										
										
										
											2019-06-15 21:34:34 +10:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64); | 
					
						
							|  |  |  |             ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64); | 
					
						
							|  |  |  |             TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // auto shapeInfo = descriptor.originalShape().toShapeInfo();
 | 
					
						
							|  |  |  |             // shape::TAD tad;
 | 
					
						
							|  |  |  |             // tad.init(shapeInfo, descriptor.axis().data(), descriptor.axis().size());
 | 
					
						
							|  |  |  |             // tad.createTadOnlyShapeInfo();
 | 
					
						
							|  |  |  |             // tad.createOffsets();
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // auto sPtr = new Nd4jLong[shape::shapeInfoLength(tad.tadOnlyShapeInfo)];
 | 
					
						
							|  |  |  |             // auto oPtr = new Nd4jLong[tad.numTads];
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // memcpy(sPtr, tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo));
 | 
					
						
							|  |  |  |             // memcpy(oPtr, tad.tadOffsets, tad.numTads * sizeof(Nd4jLong));
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // TadPack t(shapesBuffer, offsetsBuffer, tad.numTads);
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             _cache[deviceId][descriptor] = t; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             TadPack &r = _cache[deviceId][descriptor]; | 
					
						
							|  |  |  |             _mutex.unlock(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             delete[] shapeInfo; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return r; | 
					
						
							|  |  |  |         } else { | 
					
						
							| 
									
										
										
										
											2019-09-03 22:02:02 +03:00
										 |  |  |             TadPack r = _cache[deviceId][descriptor]; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             _mutex.unlock(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return r; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |     sd::ConstantTadHelper* sd::ConstantTadHelper::_INSTANCE = 0; | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif
 |