[WIP] More of CUDA operations (#69)
* initial commit. Signed-off-by: raver119 <raver119@gmail.com>
* gruCell_bp further. Signed-off-by: Yurii <yurii@skymind.io>
* further work on gruCell_bp. Signed-off-by: Yurii <yurii@skymind.io>
* Inverse matrix cublas implementation. Partial working revision.
* Separation of segment ops helpers. Max separation.
* Separated segment_min ops.
* Separation of segment_mean/sum/prod/sqrtN ops helpers.
* Fixed diagonal processing with LUP decomposition.
* Modified inversion approach using current state of LU decomposition.
* Implementation of matrix_inverse op with cuda kernels. Working revision.
* Implemented sequence_mask cuda helper. Eliminated wasteful printf from matrix_inverse implementation. Added proper tests.
* Further work on gruCell_bp (ff/cuda). Signed-off-by: Yurii <yurii@skymind.io>
* Commented one test for gruCell_bp. Signed-off-by: Yurii <yurii@skymind.io>
* Provide cuda static_rnn. Signed-off-by: Yurii <yurii@skymind.io>
* Refactored random_shuffle op to use new random generator.
* Refactored random_shuffle op helper.
* Fixed debug tests with random ops tests.
* Implement random_shuffle op cuda kernel helper and tests.
* Provide cuda scatter_update. Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of random_shuffle for linear case with cuda kernels and tests.
* Implemented random_shuffle with cuda kernels. Final revision.
* Finally gruCell_bp is completed. Signed-off-by: Yurii <yurii@skymind.io>
* Dropout op cuda helper implementation.
* Implemented dropout_bp cuda helper.
* Implemented alpha_dropout_bp with cuda kernel helpers.
* Refactored helper.
* Implementation of suppression helper with cuda kernels.
* Provide cpu code for hsvToRgb, rgbToHsv, adjustHue. Signed-off-by: Yurii <yurii@skymind.io>
* Using sort by value method.
* Implementation of image.non_max_suppression op cuda-based helper.
* Correcting and testing adjust_hue, adjust_saturation cpu/cuda code. Signed-off-by: Yurii <yurii@skymind.io>
* Added cuda device prefixes to declarations.
* Implementation of hashcode op with cuda helper. Initial revision.
* rnn cu impl removed. Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
parent 06e4f5f96e
commit 763a225c6a
@@ -208,9 +208,9 @@ namespace nd4j {
         NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());

         /**
-        *  this constructor creates new array using given buffer (without memory allocating) and shape information stored in shape
+        *  this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape
         */
-        NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape,  nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
+        NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape,  nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);

         /**
         *  this constructor creates new NDArray with shape matching "other" array,
@@ -132,9 +132,8 @@ NDArray::NDArray(const NDArray *other, const bool copyStrides, nd4j::LaunchConte
         _buffer = std::make_shared<DataBuffer>(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace());
 }

-
 ////////////////////////////////////////////////////////////////////////
-NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &shape,  nd4j::DataType dtype, nd4j::LaunchContext * context) {
+NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &shape,  nd4j::DataType dtype, nd4j::LaunchContext * context, const bool isBuffAlloc) {

     if (shape.empty())
         throw std::runtime_error("NDArray constructor: input shape is empty !");
@@ -148,7 +147,7 @@ NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &sh

     setShapeInfo(ShapeDescriptor(dtype, order, shape));

-    _buffer = std::make_shared<DataBuffer>(buffer, lengthOf() * sizeOfT(), dataType(), true, getContext()->getWorkspace());
+    _buffer = std::make_shared<DataBuffer>(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace());
 }

 ////////////////////////////////////////////////////////////////////////
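For context, the new isBuffAlloc flag is forwarded to DataBuffer and decides whether the array takes ownership of the externally supplied pointer (previously this was hard-coded to true). A minimal usage sketch, assuming the constructor defaults shown above; the surrounding setup is illustrative only:

// Sketch: wrap caller-owned memory without transferring ownership.
// Because isBuffAlloc == false (the new default), the DataBuffer will not
// try to release raw.data() when the NDArray is destroyed.
std::vector<float> raw(2 * 3, 1.0f);
nd4j::NDArray wrapped(raw.data(), 'c', {2, 3}, nd4j::DataType::FLOAT32,
                      nd4j::LaunchContext::defaultContext(), /*isBuffAlloc=*/false);

// Passing true instead tells the DataBuffer it owns the pointer and may free it,
// which is only valid for memory that was allocated the way the buffer expects.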
@@ -1498,16 +1498,6 @@ void NativeOps::specialConcat(
  * This method saves
  */
 nd4j::TadPack* NativeOps::tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensionLength) {
-	/*shape::TAD tad;
-	tad.init(dXShapeInfo, dimension, dimensionLength);
-	//tad->setOutputBuffer(target);
-	tad.createTadOnlyShapeInfo();
-	tad.createOffsets();
-
-
-	std::memcpy(reinterpret_cast<void *>(target), tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo));
-	std::memcpy(reinterpret_cast<void *>(offsets), tad.tadOffsets, tad.numTads * sizeof(Nd4jLong));
-	*/
 	auto pack = new TadPack();
 	*pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength);
     return pack;
@@ -45,9 +45,9 @@ namespace nd4j {

         static ConstantTadHelper* getInstance();

-        TadPack& tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
-        TadPack& tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
-        TadPack& tadForDimensions(Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
+        TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
+        TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
+        TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
         TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
         TadPack& tadForDimensions(TadDescriptor &descriptor);
     };
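Const-qualifying originalShape matters because shape info typically reaches callers as a const pointer. A hedged sketch of a call the new overloads accept without a const_cast; the array construction and the numberOfTads accessor are assumptions, not part of this diff:

// Sketch: request the TAD pack for dimension 1 of a 4x5 array.
nd4j::NDArray arr('c', {4, 5}, nd4j::DataType::FLOAT32);
const Nd4jLong* shapeInfo = arr.getShapeInfo();   // treated as const by the caller

std::vector<int> dims = {1};
nd4j::TadPack& pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(shapeInfo, dims);
// Each TAD is one row of length 5, so for this shape there should be 4 of them.
auto numTads = pack.numberOfTads();               // assumed accessor on TadPack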
@@ -38,15 +38,15 @@ namespace nd4j {
         return _INSTANCE;
     }

-    TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
+    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
         return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
     }

-    TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
+    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
         return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
     }

-    TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
+    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
         TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
         return tadForDimensions(tadDescriptor);
     }
@@ -42,15 +42,15 @@ namespace nd4j {
         return _INSTANCE;
     }

-    TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
+    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
         return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
     }

-    TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
+    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
         return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
     }

-    TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
+    TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
         TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
         return tadForDimensions(tadDescriptor);
     }
@@ -58,7 +58,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
 	const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
 	const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();

-	// fill input gradient arrays in accordance to type of loss function
+	// fill input gradient arrays in accordance to kind of loss function
 	fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));

 	// beck prop pass
@@ -987,9 +987,10 @@ namespace shape {
     // dimsToExclude - should be sorted in increasing order
     ND4J_EXPORT _CUDA_HD int outerArrayIndexes(Nd4jLong* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr);

-    // calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
+    // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
+    // maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
     // dimsToExclude - should be sorted in increasing order
-    // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
+    // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
     ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude = nullptr);

     // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
@@ -16,6 +16,7 @@

 //
 // @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <op_boilerplate.h>
@@ -28,46 +29,35 @@
 namespace nd4j {
 namespace ops {

-    DECLARE_TYPES(adjust_hue) {
-        getOpDescriptor()
-                ->setAllowedInputTypes(nd4j::DataType::ANY)
-                ->setSameMode(true);
-    }
-
-    CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, -2, -2) {
+CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
+
     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);

-        REQUIRE_TRUE(input->rankOf() == 3 || input->rankOf() == 4, 0, "AdjustHue: op expects either 3D or 4D input, but got %i instead", input->rankOf());
+    const int rank     = input->rankOf();
+    const int dimC     = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
+    const double delta = T_ARG(0);

-        double delta = 0;
-        if (block.numT() > 0)
-            delta = T_ARG(0);
-        else if (block.width() > 1) {
-            auto _d = INPUT_VARIABLE(1);
-            if (!_d->isScalar()) {
-                auto str = ShapeUtils::shapeAsString(_d);
-                REQUIRE_TRUE(_d->isScalar(), 0, "AdjustHue: delta should be scalar NDArray, but got %s instead", str.c_str());
-            }
-            delta = _d->e<double>(0);
-        }
+    REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank);
+    REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
+    REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);

-        bool isNHWC = false;
-        if (block.numI() > 0)
-            isNHWC = INT_ARG(0) == 1;
-
-        int numChannels = isNHWC ? input->sizeAt(-1) : input->sizeAt(-3);
-
-        REQUIRE_TRUE(numChannels == 3, 0, "AdjustHue: this operation expects image with 3 channels (R, G, B), but got % instead", numChannels);
-
-        auto ts = NDArrayFactory::create(delta, block.launchContext());
-        // FIXME: delta should be NDArray scalar
-        helpers::_adjust_hue(block.launchContext(), input, output, &ts, isNHWC);
+    NDArray deltaScalarArr = NDArrayFactory::create<double>(delta, block.launchContext());
+
+    helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC);

     return Status::OK();
 }
+
+DECLARE_TYPES(adjust_hue) {
+    getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
+                     ->setSameMode(true);
+}
+
 }
 }
@@ -27,45 +27,33 @@

 namespace nd4j {
 namespace ops {
-    DECLARE_TYPES(adjust_saturation) {
-        getOpDescriptor()
-                ->setAllowedInputTypes(nd4j::DataType::ANY)
-                ->setSameMode(true);
-    }
-
-    CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, -2, -2) {
+CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
+
     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);

-        REQUIRE_TRUE(input->rankOf() == 3 || input->rankOf() == 4, 0, "AdjustSaturation: op expects either 3D or 4D input, but got %i instead", input->rankOf());
+    const int rank     = input->rankOf();
+    const int dimC     = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
+    const double factor = T_ARG(0);

-        double delta = 0;
-        if (block.numT() > 0)
-            delta = T_ARG(0);
-        else if (block.width() > 1) {
-            auto _d = INPUT_VARIABLE(1);
-            if (!_d->isScalar()) {
-                auto str = ShapeUtils::shapeAsString(_d);
-                REQUIRE_TRUE(_d->isScalar(), 0, "AdjustSaturation: delta should be scalar NDArray, but got %s instead", str.c_str());
-            }
-
-            delta = _d->e<double>(0);
-        }
+    REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank);
+    REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));

-        bool isNHWC = false;
-        if (block.numI() > 0)
-            isNHWC = INT_ARG(0) == 1;
-
-        int numChannels = isNHWC ? input->sizeAt(-1) : input->sizeAt(-3);
-
-        REQUIRE_TRUE(numChannels == 3, 0, "AdjustSaturation: this operation expects image with 3 channels (R, G, B), but got % instead", numChannels);
-
-        auto ts = NDArrayFactory::create(delta, block.launchContext());
-        // FIXME: delta should be NDArray scalar
-        helpers::adjust_saturation(block.launchContext(), input, output, &ts, isNHWC);
+    NDArray factorScalarArr = NDArrayFactory::create<double>(factor, block.launchContext());
+
+    helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC);

     return Status::OK();
 }
+
+DECLARE_TYPES(adjust_saturation) {
+    getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
+                     ->setSameMode(true);
+}
+
 }
 }
@@ -27,6 +27,7 @@

 namespace nd4j {
 namespace ops {
+
 OP_IMPL(scatter_add, 3, 1, true) {
     auto input = INPUT_VARIABLE(0);
     auto indices = INPUT_VARIABLE(1);
@@ -74,8 +75,8 @@ namespace nd4j {

     return Status::OK();
 }
+
 DECLARE_SYN(ScatterAdd, scatter_add);
-    }

 DECLARE_TYPES(scatter_add) {
     getOpDescriptor()
@@ -84,6 +85,8 @@ namespace nd4j {
         ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
         ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
 }
+
+}
 }

 #endif
@@ -57,16 +57,26 @@ namespace nd4j {
             auto in = inputShape->at(0);
             int outRank = shape::rank(in) + 1;
             auto input = INPUT_VARIABLE(0);
+            auto dtype = DataType::BOOL;
             Nd4jLong maxInd = input->argMax();
-            float max = input->e<float>(maxInd);
+            Nd4jLong max = input->e<Nd4jLong>(maxInd);
+
             if (block.getIArguments()->size() > 0) {
+                if (block.width() < 2) {
                 maxInd = INT_ARG(0);
                 if (maxInd < max)
                     maxInd = static_cast<Nd4jLong>(max);
+                if (block.getIArguments()->size() > 1)
+                    dtype = (DataType)INT_ARG(1);
                 }
-            else if (block.width() > 1) {
+                else {
+                    dtype = (DataType)INT_ARG(0);
+                }
+            }
+
+            if (block.width() > 1) {
                 auto maxlen = INPUT_VARIABLE(1);
-                float tmaxlen = maxlen->e<float>(0);
+                Nd4jLong tmaxlen = maxlen->e<Nd4jLong>(0);
                 if (tmaxlen > max)
                     maxInd = static_cast<Nd4jLong>(tmaxlen);
             }
@@ -80,14 +90,14 @@ namespace nd4j {
                 outShapeInfo[i + 1] = shape::sizeAt(in, i);
             outShapeInfo[outRank] = lastDimension;

-            ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in));
+            ShapeUtils::updateStridesAndType(outShapeInfo, dtype, shape::order(in));

             return SHAPELIST(CONSTANT(outShapeInfo));
     }

         DECLARE_TYPES(sequence_mask) {
             getOpDescriptor()
-                    ->setAllowedInputTypes(nd4j::DataType::ANY)
+                    ->setAllowedInputTypes({ALL_INTS})
                     ->setAllowedOutputTypes(nd4j::DataType::ANY);
         }
 }
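The shape function above now has to disambiguate the integer arguments: when a second input array supplies the maximum length, the only meaningful I-arg is the output data type, otherwise the first I-arg is the maximum length and an optional second one is the data type. A small standalone sketch of that resolution, with illustrative names that are not part of the op:

// Sketch: mirror of the I-arg resolution in the shape function above.
// Returns {maxLen, dtypeOrdinal}; -1 means "not provided" (dtype then stays BOOL).
#include <utility>
#include <vector>

std::pair<int, int> resolveSequenceMaskArgs(const std::vector<int>& iArgs, bool hasMaxLenInput) {
    int maxLen = -1, dtype = -1;
    if (!iArgs.empty()) {
        if (!hasMaxLenInput) {
            maxLen = iArgs[0];                   // first I-arg is the requested mask length
            if (iArgs.size() > 1)
                dtype = iArgs[1];                // optional second I-arg is the output type
        } else {
            dtype = iArgs[0];                    // max length comes from the second input instead
        }
    }
    return {maxLen, dtype};
}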
@@ -33,11 +33,11 @@ OP_IMPL(random_shuffle, 1, 1, true) {
     const bool isInplace = block.isInplace();
     auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0);

-    nd4j::random::RandomBuffer* rng = block.getRNG();
+//    nd4j::random::RandomBuffer* rng = block.getRNG();
+    nd4j::graph::RandomGenerator rng = block.randomGenerator();
+//    REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !");

-    REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !");
-
-    helpers::randomShuffle(block.launchContext(), *input, *output, *rng, isInplace);
+    helpers::randomShuffle(block.launchContext(), *input, *output, rng, isInplace);

     return Status::OK();
 }
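The op now draws randomness from the graph-level RandomGenerator instead of the legacy RandomBuffer. For the linear (flat, contiguous) case mentioned in the commit message, a shuffle of this kind boils down to a Fisher-Yates pass; a generic sketch of that idea, not the helper's actual code:

// Sketch: in-place Fisher-Yates shuffle over a flat buffer, driven by any
// deterministic per-index random source (standing in for RandomGenerator).
#include <cstdint>
#include <functional>
#include <utility>

template <typename T>
void linearShuffle(T* data, int64_t length, const std::function<int64_t(int64_t)>& nextRandom) {
    // nextRandom(i) is assumed to return a value in [0, i] for position i.
    for (int64_t i = length - 1; i > 0; --i) {
        const int64_t j = nextRandom(i);
        std::swap(data[i], data[j]);
    }
}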
@@ -31,6 +31,7 @@ namespace ops  {

 //////////////////////////////////////////////////////////////////////////
 CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) {
+
     auto x     = INPUT_VARIABLE(0);                   // input [bS, nIn], nIn - input size
     auto hLast = INPUT_VARIABLE(1);                   // previous cell output [bS, nU],  that is at previous time step t-1, nU - number of units
     auto Wru   = INPUT_VARIABLE(2);                   // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights)
@@ -118,65 +119,58 @@ DECLARE_SHAPE_FN(gruCell) {


 //////////////////////////////////////////////////////////////////////////
-CUSTOM_OP_IMPL(gruCell_bp, 6, 5, false, 0, 0) {
+CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) {

     auto x      = INPUT_VARIABLE(0);                                // input [bS x iS]
     auto hi     = INPUT_VARIABLE(1);                                // previous cell output [bS x nU]
-    auto Wx     = INPUT_VARIABLE(2);                                 // input-to-hidden  weights, [iS x 3*nU]
-    auto Wh     = INPUT_VARIABLE(3);                                 // hidden-to-hidden weights, [nU x 3*nU]
-    auto b      = INPUT_VARIABLE(4);                                 // biases, [3*nU]
-    auto dLdh   = INPUT_VARIABLE(5);                                 // gradient wrt output, [bS,nU], that is epsilon_next
-    auto dLdWxi = block.width() > 6 ? INPUT_VARIABLE(6) : nullptr;   // gradient wrt Wx at previous time step, [iS, 3*nU]
-    auto dLdWhi = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr;   // gradient wrt Wh at previous time step, [nU, 3*nU]
-    auto dLdbi  = block.width() > 8 ? INPUT_VARIABLE(8) : nullptr;   // gradient wrt b at previous time step,  [3*nU]
+    auto W      = INPUT_VARIABLE(2);                                // weights, [iS+nU x 2*nU]
+    auto Wc     = INPUT_VARIABLE(3);                                // c weights, [iS+nU x nU]
+    auto b      = INPUT_VARIABLE(4);                                // biases, [2*nU]
+    auto bc     = INPUT_VARIABLE(5);                                // biases, [nU]
+    auto dLdr   = INPUT_VARIABLE(6);                                // gradient wrt reset gate, [bS, nU]
+    auto dLdu   = INPUT_VARIABLE(7);                                // gradient wrt update gate, [bS, nU]
+    auto dLdc   = INPUT_VARIABLE(8);                                // gradient wrt cell state, [bS, nU]
+    auto dLdh   = INPUT_VARIABLE(9);                                // gradient wrt current cell output, [bS, nU]

-    auto dLdx   = OUTPUT_VARIABLE(0);                                // gradient wrt x,  [bS, iS], that is epsilon
-    auto dLdhi  = OUTPUT_VARIABLE(1);                               // gradient wrt hi, [bS, nU]
-    auto dLdWx  = OUTPUT_VARIABLE(2);                                // gradient wrt Wx, [iS, 3*nU]
-    auto dLdWh  = OUTPUT_VARIABLE(3);                                // gradient wrt Wh, [nU, 3*nU]
-    auto dLdb   = OUTPUT_VARIABLE(4);                                // gradient wrt biases,  [3*nU]
+    auto dLdx   = OUTPUT_VARIABLE(0);                               // gradient wrt x,  [bS, iS]
+    auto dLdhi  = OUTPUT_VARIABLE(1);                               // gradient wrt hi, [bS, nU]
+    auto dLdW   = OUTPUT_VARIABLE(2);                               // gradient wrt W,  [iS+nU x 2*nU]
+    auto dLdWc  = OUTPUT_VARIABLE(3);                               // gradient wrt Wc, [iS+nU x nU]
+    auto dLdb   = OUTPUT_VARIABLE(4);                               // gradient wrt biases, [2*nU]
+    auto dLdbc  = OUTPUT_VARIABLE(5);                               // gradient wrt c biases, [nU]

-    const int rank     = x->rankOf();                               // = 2
     const Nd4jLong bS = x->sizeAt(0);
     const Nd4jLong iS = x->sizeAt(1);
     const Nd4jLong nU = hi->sizeAt(1);

+    REQUIRE_TRUE(x->rankOf() == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", x->rankOf());
+
     const std::string hiShape        = ShapeUtils::shapeAsString(hi);
     const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
-    const std::string wxShape          = ShapeUtils::shapeAsString(Wx);
-    const std::string wxCorrectShape   = ShapeUtils::shapeAsString({iS, 3*nU});
-    const std::string whShape          = ShapeUtils::shapeAsString(Wh);
-    const std::string whCorrectShape   = ShapeUtils::shapeAsString({nU, 3*nU});
+    const std::string wShape         = ShapeUtils::shapeAsString(W);
+    const std::string wCorrectShape  = ShapeUtils::shapeAsString({iS+nU, 2*nU});
+    const std::string wcShape        = ShapeUtils::shapeAsString(Wc);
+    const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
     const std::string bShape         = ShapeUtils::shapeAsString(b);
-    const std::string bCorrectShape    = ShapeUtils::shapeAsString({3*nU});
+    const std::string bCorrectShape  = ShapeUtils::shapeAsString({2*nU});
+    const std::string bcShape        = ShapeUtils::shapeAsString(bc);
+    const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
+    const std::string dLdrShape      = ShapeUtils::shapeAsString(dLdr);
+    const std::string dLduShape      = ShapeUtils::shapeAsString(dLdu);
+    const std::string dLdcShape      = ShapeUtils::shapeAsString(dLdc);
     const std::string dLdhShape      = ShapeUtils::shapeAsString(dLdh);
-    const std::string dLdhCorrectShape = ShapeUtils::shapeAsString({bS, nU});

     REQUIRE_TRUE(hiShape   == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
-    REQUIRE_TRUE(wxShape   == wxCorrectShape,    0, "GRU_CELL_BP op: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
-    REQUIRE_TRUE(whShape   == whCorrectShape,    0, "GRU_CELL_BP op: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
+    REQUIRE_TRUE(wShape    == wCorrectShape,   0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
+    REQUIRE_TRUE(wcShape   == wcCorrectShape,  0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
     REQUIRE_TRUE(bShape    == bCorrectShape,   0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
-    REQUIRE_TRUE(dLdhShape == dLdhCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdh array (epsilon_next), expected is %s, but got %s instead !", dLdhCorrectShape.c_str(), dLdhShape.c_str());
+    REQUIRE_TRUE(bcShape   == bcCorrectShape,  0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
+    REQUIRE_TRUE(dLdrShape == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
+    REQUIRE_TRUE(dLduShape == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
+    REQUIRE_TRUE(dLdcShape == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
+    REQUIRE_TRUE(dLdhShape == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());

-    if(dLdWxi != nullptr) {
-        const std::string dLdWxiShape        = ShapeUtils::shapeAsString(dLdWxi);
-        const std::string dLdWxiCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
-        REQUIRE_TRUE(dLdWxiShape == dLdWxiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdWxi array (gradient wrt Wx at previous time step), expected is %s, but got %s instead !", dLdWxiCorrectShape.c_str(), dLdWxiShape.c_str());
-    }
-
-    if(dLdWhi != nullptr) {
-        const std::string dLdWhiShape        = ShapeUtils::shapeAsString(dLdWhi);
-        const std::string dLdWhiCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
-        REQUIRE_TRUE(dLdWhiShape == dLdWhiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdWhi array (gradient wrt Wh at previous time step), expected is %s, but got %s instead !", dLdWhiCorrectShape.c_str(), dLdWhiShape.c_str());
-    }
-
-    if(dLdbi != nullptr) {
-        const std::string dLdbiShape        = ShapeUtils::shapeAsString(dLdbi);
-        const std::string dLdbiCorrectShape = ShapeUtils::shapeAsString({3*nU});
-        REQUIRE_TRUE(dLdbiShape == dLdbiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdbi array (gradient wrt biases at previous time step), expected is %s, but got %s instead !", dLdbiCorrectShape.c_str(), dLdbiShape.c_str());
-    }
-
-    helpers::gruCellBP(block.launchContext(), x,  hi, Wx, Wh, b, dLdh, dLdWxi, dLdWhi, dLdbi, dLdx, dLdhi, dLdWx, dLdWh, dLdb);
+    helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);

     return Status::OK();
 }
@@ -192,6 +186,7 @@ DECLARE_TYPES(gruCell_bp) {
         ->setAllowedInputTypes(6, {ALL_FLOATS})
         ->setAllowedInputTypes(7, {ALL_FLOATS})
        ->setAllowedInputTypes(8, {ALL_FLOATS})
+        ->setAllowedInputTypes(9, {ALL_FLOATS})
         ->setAllowedOutputTypes({ALL_FLOATS});
 }

@@ -199,53 +194,46 @@ DECLARE_SHAPE_FN(gruCell_bp) {

     auto xShapeInfo    = inputShape->at(0);                          // [bS x iS]
     auto hiShapeInfo   = inputShape->at(1);                          // [bS x nU]
-    auto wxShapeInfo     = inputShape->at(2);                        // [iS x 3*nU]
-    auto whShapeInfo     = inputShape->at(3);                        // [nU x 3*nU]
-    auto bShapeInfo      = inputShape->at(4);                        // [3*nU]
-    auto dLdhShapeInfo   = inputShape->at(5);                        // [bS x nU]
+    auto wShapeInfo    = inputShape->at(2);                          // [iS+nU x 2*nU]
+    auto wcShapeInfo   = inputShape->at(3);                          // [iS+nU x nU]
+    auto bShapeInfo    = inputShape->at(4);                          // [2*nU]
+    auto bcShapeInfo   = inputShape->at(5);                          // [nU]
+    auto dLdrShapeInfo = inputShape->at(6);                          // [bS, nU]
+    auto dLduShapeInfo = inputShape->at(7);                          // [bS, nU]
+    auto dLdcShapeInfo = inputShape->at(8);                          // [bS, nU]
+    auto dLdhShapeInfo = inputShape->at(9);                          // [bS, nU]

     const int rank    = xShapeInfo[0];                               // = 2
     const Nd4jLong bS = xShapeInfo[1];
     const Nd4jLong iS = xShapeInfo[2];
     const Nd4jLong nU = hiShapeInfo[2];

+    REQUIRE_TRUE(xShapeInfo[0] == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", xShapeInfo[0]);
+
     const std::string hiShape        = ShapeUtils::shapeAsString(hiShapeInfo);
     const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
-    const std::string wxShape          = ShapeUtils::shapeAsString(wxShapeInfo);
-    const std::string wxCorrectShape   = ShapeUtils::shapeAsString({iS, 3*nU});
-    const std::string whShape          = ShapeUtils::shapeAsString(whShapeInfo);
-    const std::string whCorrectShape   = ShapeUtils::shapeAsString({nU, 3*nU});
+    const std::string wShape         = ShapeUtils::shapeAsString(wShapeInfo);
+    const std::string wCorrectShape  = ShapeUtils::shapeAsString({iS+nU, 2*nU});
+    const std::string wcShape        = ShapeUtils::shapeAsString(wcShapeInfo);
+    const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
     const std::string bShape         = ShapeUtils::shapeAsString(bShapeInfo);
-    const std::string bCorrectShape    = ShapeUtils::shapeAsString({3*nU});
+    const std::string bCorrectShape  = ShapeUtils::shapeAsString({2*nU});
+    const std::string bcShape        = ShapeUtils::shapeAsString(bcShapeInfo);
+    const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
+    const std::string dLdrShape      = ShapeUtils::shapeAsString(dLdrShapeInfo);
+    const std::string dLduShape      = ShapeUtils::shapeAsString(dLduShapeInfo);
+    const std::string dLdcShape      = ShapeUtils::shapeAsString(dLdcShapeInfo);
     const std::string dLdhShape      = ShapeUtils::shapeAsString(dLdhShapeInfo);
-    const std::string dLdhCorrectShape = ShapeUtils::shapeAsString({bS, nU});

     REQUIRE_TRUE(hiShape   == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
-    REQUIRE_TRUE(wxShape   == wxCorrectShape,    0, "GRU_CELL_BP op: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
-    REQUIRE_TRUE(whShape   == whCorrectShape,    0, "GRU_CELL_BP op: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
+    REQUIRE_TRUE(wShape    == wCorrectShape,   0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
+    REQUIRE_TRUE(wcShape   == wcCorrectShape,  0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
     REQUIRE_TRUE(bShape    == bCorrectShape,   0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
-    REQUIRE_TRUE(dLdhShape == dLdhCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdh array (epsilon_next), expected is %s, but got %s instead !", dLdhCorrectShape.c_str(), dLdhShape.c_str());
-
-    if(block.width() > 6) {
-        Nd4jLong* dLdWxiShapeInfo = inputShape->at(6);                                              // [iS x 3*nU]
-        const std::string dLdWxiShape        = ShapeUtils::shapeAsString(dLdWxiShapeInfo);
-        const std::string dLdWxiCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
-        REQUIRE_TRUE(dLdWxiShape == dLdWxiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdWxi array (gradient wrt Wx at previous time step), expected is %s, but got %s instead !", dLdWxiCorrectShape.c_str(), dLdWxiShape.c_str());
-    }
-
-    if(block.width() > 7) {
-        Nd4jLong* dLdWhiShapeInfo = inputShape->at(7);                                              // [nU x 3*nU]
-        const std::string dLdWhiShape        = ShapeUtils::shapeAsString(dLdWhiShapeInfo);
-        const std::string dLdWhiCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
-        REQUIRE_TRUE(dLdWhiShape == dLdWhiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdWhi array (gradient wrt Wh at previous time step), expected is %s, but got %s instead !", dLdWhiCorrectShape.c_str(), dLdWhiShape.c_str());
-    }
-
-    if(block.width() > 8) {
-        Nd4jLong* dLdbiShapeInfo  = inputShape->at(8);                                              // [3*nU]
-        const std::string dLdbiShape        = ShapeUtils::shapeAsString(dLdbiShapeInfo);
-        const std::string dLdbiCorrectShape = ShapeUtils::shapeAsString({3*nU});
-        REQUIRE_TRUE(dLdbiShape == dLdbiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdbi array (gradient wrt biases at previous time step), expected is %s, but got %s instead !", dLdbiCorrectShape.c_str(), dLdbiShape.c_str());
-    }
+    REQUIRE_TRUE(bcShape   == bcCorrectShape,  0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
+    REQUIRE_TRUE(dLdrShape == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
+    REQUIRE_TRUE(dLduShape == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
+    REQUIRE_TRUE(dLdcShape == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
+    REQUIRE_TRUE(dLdhShape == hiCorrectShape,  0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());

     Nd4jLong *dLdxShapeInfo = nullptr;
     COPY_SHAPE(xShapeInfo, dLdxShapeInfo);
@@ -253,17 +241,19 @@ DECLARE_SHAPE_FN(gruCell_bp) {
     Nd4jLong *dLdhiShapeInfo = nullptr;
     COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo);

-    Nd4jLong *dLdWxShapeInfo = nullptr;
-    COPY_SHAPE(wxShapeInfo, dLdWxShapeInfo);
+    Nd4jLong *dLdWShapeInfo = nullptr;
+    COPY_SHAPE(wShapeInfo, dLdWShapeInfo);

-    Nd4jLong *dLdWhShapeInfo = nullptr;
-    COPY_SHAPE(whShapeInfo, dLdWhShapeInfo);
+    Nd4jLong *dLdWcShapeInfo = nullptr;
+    COPY_SHAPE(wcShapeInfo, dLdWcShapeInfo);

     Nd4jLong *dLdbShapeInfo = nullptr;
     COPY_SHAPE(bShapeInfo, dLdbShapeInfo);

-    return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo);
+    Nd4jLong *dLdbcShapeInfo = nullptr;
+    COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo);
+
+    return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo);
 }
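To make the reworked interface easier to review: the backward op now mirrors gruCell's forward inputs (x, hLast, W, Wc, b, bc) and additionally takes four incoming gradients, returning one gradient per forward input. A shape-only sketch derived from the checks above; the NDArray construction is an assumption used for illustration:

// Sketch: shapes of the 10 inputs / 6 outputs of the new gruCell_bp
// (bS = batch size, iS = input size, nU = number of units).
const Nd4jLong bS = 2, iS = 3, nU = 4;
nd4j::NDArray x   ('c', {bS, iS},          nd4j::DataType::FLOAT32);  // 0: input
nd4j::NDArray hi  ('c', {bS, nU},          nd4j::DataType::FLOAT32);  // 1: previous cell output
nd4j::NDArray W   ('c', {iS + nU, 2 * nU}, nd4j::DataType::FLOAT32);  // 2: reset/update weights
nd4j::NDArray Wc  ('c', {iS + nU, nU},     nd4j::DataType::FLOAT32);  // 3: cell weights
nd4j::NDArray b   ('c', {2 * nU},          nd4j::DataType::FLOAT32);  // 4: reset/update biases
nd4j::NDArray bc  ('c', {nU},              nd4j::DataType::FLOAT32);  // 5: cell biases
nd4j::NDArray dLdr('c', {bS, nU},          nd4j::DataType::FLOAT32);  // 6: grad wrt reset gate
nd4j::NDArray dLdu('c', {bS, nU},          nd4j::DataType::FLOAT32);  // 7: grad wrt update gate
nd4j::NDArray dLdc('c', {bS, nU},          nd4j::DataType::FLOAT32);  // 8: grad wrt cell state
nd4j::NDArray dLdh('c', {bS, nU},          nd4j::DataType::FLOAT32);  // 9: grad wrt current output
// Outputs, in order: dLdx [bS, iS], dLdhi [bS, nU], dLdW [iS+nU, 2*nU],
//                    dLdWc [iS+nU, nU], dLdb [2*nU], dLdbc [nU].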
@@ -553,33 +553,31 @@ namespace nd4j {
         /**
          * This operation adjusts image hue by delta
          * Input arrays:
-         * 0 - 1D or 3D input array, must have 3 channels.
-         * 1 - optional scalar, delta value
+         * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
          *
          * T arguments:
-         * 0 - optional delta value
+         * 0 - delta value
          *
          * Int arguments:
-         * 0 - optional argument, isNHWC. false by default.
+         * 0 - optional argument, corresponds to dimension with 3 channels
          */
         #if NOT_EXCLUDED(OP_adjust_hue)
-        DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, -2, -2);
+        DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2);
         #endif

         /**
          * This operation adjusts image saturation by delta
          * Input arrays:
-         * 0 - 1D or 3D input array, must have 3 channels.
-         * 1 - optional scalar, delta value
+         * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
          *
          * T arguments:
-         * 0 - optional delta value
+         * 0 - saturation factor
          *
         * Int arguments:
-         * 0 - optional argument, isNHWC. false by default.
+         * 0 - optional argument, corresponds to dimension with 3 channels
          */
         #if NOT_EXCLUDED(OP_adjust_saturation)
-        DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, -2, -2);
+        DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
         #endif
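Both ops now locate the channel dimension through the optional integer argument instead of an isNHWC flag; with no I-arg the last dimension is assumed to hold the three channels. A small standalone sketch of that resolution (the function name is illustrative only, mirroring the dimC computation in the op implementations):

// Sketch: how the channel-dimension argument is interpreted by adjust_hue / adjust_saturation.
#include <vector>

int resolveChannelDim(int rank, const std::vector<int>& iArgs) {
    if (iArgs.empty())
        return rank - 1;                                  // default: channels-last (NHWC-style)
    return iArgs[0] >= 0 ? iArgs[0] : iArgs[0] + rank;    // negative axis counts from the end
}

// resolveChannelDim(4, {})   -> 3   e.g. NHWC image batch
// resolveChannelDim(4, {1})  -> 1   e.g. NCHW image batch
// resolveChannelDim(3, {-3}) -> 0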
@@ -259,8 +259,8 @@ namespace ops  {
       * Input arrays:
       *    0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features
       *    1: previous cell output [batchSize x numUnits],  that is at previous time step t-1
-      *    2: RU weights - [(nIn+nOut), 2*numUnits] - reset and update gates (input/recurrent weights)
-      *    3: C weights - [(nIn+nOut), numUnits] - cell gate (input/recurrent weights)
+      *    2: RU weights - [(inSize+numUnits), 2*numUnits] - reset and update gates (input/recurrent weights)
+      *    3: C weights - [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights)
       *    4: reset and update biases, [2*numUnits] - reset and update gates
       *    5: cell biases, [numUnits]
       *
@@ -275,7 +275,7 @@ namespace ops  {
         #endif

         #if NOT_EXCLUDED(OP_gruCell)
-        DECLARE_CUSTOM_OP(gruCell_bp, 6, 5, false, 0, 0);
+        DECLARE_CUSTOM_OP(gruCell_bp, 10, 6, false, 0, 0);
         #endif

     //////////////////////////////////////////////////////////////////////////
| @ -16,6 +16,7 @@ | |||||||
| 
 | 
 | ||||||
| //
 | //
 | ||||||
| // @author raver119@gmail.com
 | // @author raver119@gmail.com
 | ||||||
|  | // @author Yurii Shyrma (iuriish@yahoo.com)
 | ||||||
| //
 | //
 | ||||||
| 
 | 
 | ||||||
| #include <op_boilerplate.h> | #include <op_boilerplate.h> | ||||||
| @ -24,6 +25,88 @@ | |||||||
| namespace nd4j    { | namespace nd4j    { | ||||||
| namespace ops     { | namespace ops     { | ||||||
| namespace helpers { | namespace helpers { | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC); | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
|  | template <typename T> | ||||||
|  | FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, T& s, T& v) { | ||||||
|  | 
 | ||||||
|  |     // h values are in range [0, 360)
 | ||||||
|  |     // s and v values are in range [0, 1]
 | ||||||
|  | 
 | ||||||
|  |     const T max = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b)); | ||||||
|  |     const T min = nd4j::math::nd4j_min<T>(r, nd4j::math::nd4j_min<T>(g, b)); | ||||||
|  |     const T c  = max - min; | ||||||
|  | 
 | ||||||
|  |     // calculate h
 | ||||||
|  |     if(c == 0) { | ||||||
|  |         h = 0; | ||||||
|  |     } | ||||||
|  |     else if(max == r) { | ||||||
|  |         h = 60.f * ((g - b) / c) + (g >= b ? 0 : 360); | ||||||
|  |     } | ||||||
|  |     else if(max == g) { | ||||||
|  |         h = 60.f * ((b - r) / c) + 120; | ||||||
|  |     } | ||||||
|  |     else { // max == b
 | ||||||
|  |         h = 60.f * ((r - g) / c) + 240; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // calculate s
 | ||||||
|  |     s = max == (T)0 ? (T)0 : c / max; | ||||||
|  | 
 | ||||||
|  |     // calculate v
 | ||||||
|  |     v = max / 255.f; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
|  | template <typename T> | ||||||
|  | FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& g, T& b) { | ||||||
|  | 
 | ||||||
|  |     const float sector = h / 60.f; | ||||||
|  |     const T c = v * s; | ||||||
|  | 
 | ||||||
|  |     if(0.f <= sector && sector < 1.f) { | ||||||
|  |         r = v; | ||||||
|  |         g = v - c * (1 - sector); | ||||||
|  |         b = v - c; | ||||||
|  |     } | ||||||
|  |     else if(1.f <= sector && sector < 2.f) { | ||||||
|  |         r = v - c * (sector - 1); | ||||||
|  |         g = v; | ||||||
|  |         b = v - c; | ||||||
|  |     } | ||||||
|  |     else if(2.f <= sector && sector < 3.f) { | ||||||
|  |         r = v - c; | ||||||
|  |         g = v; | ||||||
|  |         b = v - c * (3 - sector); | ||||||
|  |     } | ||||||
|  |     else if(3.f <= sector && sector < 4.f) { | ||||||
|  |         r = v - c; | ||||||
|  |         g = v - c * (sector - 3); | ||||||
|  |         b = v; | ||||||
|  |     } | ||||||
|  |     else if(4.f <= sector && sector < 5.f) { | ||||||
|  |         r = v - c * (5 - sector); | ||||||
|  |         g = v - c; | ||||||
|  |         b = v; | ||||||
|  |     } | ||||||
|  |     else {      // 5.f <= sector < 6.f
 | ||||||
|  |         r = v; | ||||||
|  |         g = v - c; | ||||||
|  |         b = v - c * (sector - 5); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     r *= 255; | ||||||
|  |     g *= 255; | ||||||
|  |     b *= 255; | ||||||
|  | } | ||||||
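The two inline converters above assume RGB values in [0, 255], h in [0, 360) and s, v in [0, 1], and they invert each other on that range. A minimal standalone sketch of the same math for a quick round-trip sanity check (plain C++ with hypothetical helper names, not the nd4j helpers themselves):

#include <algorithm>
#include <cstdio>

// same formulas as the FORCEINLINE helpers above, written for plain floats
static void rgb2hsv(float r, float g, float b, float& h, float& s, float& v) {
    const float mx = std::max(r, std::max(g, b));
    const float mn = std::min(r, std::min(g, b));
    const float c  = mx - mn;
    if (c == 0.f)      h = 0.f;
    else if (mx == r)  h = 60.f * (g - b) / c + (g >= b ? 0.f : 360.f);
    else if (mx == g)  h = 60.f * (b - r) / c + 120.f;
    else               h = 60.f * (r - g) / c + 240.f;
    s = (mx == 0.f) ? 0.f : c / mx;
    v = mx / 255.f;
}

static void hsv2rgb(float h, float s, float v, float& r, float& g, float& b) {
    const float sector = h / 60.f;
    const float c = v * s;
    if      (sector < 1.f) { r = v;                      g = v - c * (1.f - sector); b = v - c; }
    else if (sector < 2.f) { r = v - c * (sector - 1.f); g = v;                      b = v - c; }
    else if (sector < 3.f) { r = v - c;                  g = v;                      b = v - c * (3.f - sector); }
    else if (sector < 4.f) { r = v - c;                  g = v - c * (sector - 3.f); b = v; }
    else if (sector < 5.f) { r = v - c * (5.f - sector); g = v - c;                  b = v; }
    else                   { r = v;                      g = v - c;                  b = v - c * (sector - 5.f); }
    r *= 255.f; g *= 255.f; b *= 255.f;
}

int main() {
    float h, s, v, r, g, b;
    rgb2hsv(200.f, 100.f, 50.f, h, s, v);   // expect h = 20, s = 0.75, v ~ 0.784
    hsv2rgb(h, s, v, r, g, b);              // expect the original (200, 100, 50) back
    std::printf("h=%.1f s=%.3f v=%.3f -> r=%.1f g=%.1f b=%.1f\n", h, s, v, r, g, b);
    return 0;
}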
|  | 
 | ||||||
|  | /*////////////////////////////////////////////////////////////////////////////////
 | ||||||
| template <typename T> | template <typename T> | ||||||
| static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) { | static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) { | ||||||
|     T v_mid; |     T v_mid; | ||||||
| @ -83,6 +166,7 @@ namespace helpers { | |||||||
|     *h = h_category + (increase ? ratio : (1 - ratio)); |     *h = h_category + (increase ? ratio : (1 - ratio)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
| template <typename T> | template <typename T> | ||||||
| static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) { | static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) { | ||||||
|     int h_category = static_cast<int>(h); |     int h_category = static_cast<int>(h); | ||||||
| @ -128,7 +212,7 @@ namespace helpers { | |||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|     void _adjust_hue(nd4j::LaunchContext * context, NDArray *input, NDArray *output, NDArray *delta, bool isNHWC); | */ | ||||||
| } | } | ||||||
| } | } | ||||||
| } | } | ||||||
| @ -16,6 +16,7 @@ | |||||||
| 
 | 
 | ||||||
| //
 | //
 | ||||||
| // @author raver119@gmail.com
 | // @author raver119@gmail.com
 | ||||||
|  | // @author Yurii Shyrma (iuriish@yahoo.com)
 | ||||||
| //
 | //
 | ||||||
| 
 | 
 | ||||||
| #include <op_boilerplate.h> | #include <op_boilerplate.h> | ||||||
| @ -25,6 +26,10 @@ | |||||||
| namespace nd4j    { | namespace nd4j    { | ||||||
| namespace ops     { | namespace ops     { | ||||||
| namespace helpers { | namespace helpers { | ||||||
|  | 
 | ||||||
|  |     void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC); | ||||||
|  | 
 | ||||||
|  | /*
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) { |     static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) { | ||||||
|         T vv = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b)); |         T vv = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b)); | ||||||
| @ -109,8 +114,8 @@ namespace helpers { | |||||||
|         *g = gg + m; |         *g = gg + m; | ||||||
|         *b = bb + m; |         *b = bb + m; | ||||||
|     } |     } | ||||||
|  | */ | ||||||
| 
 | 
 | ||||||
|     void adjust_saturation(nd4j::LaunchContext * context, NDArray *input, NDArray *output, NDArray *delta, bool isNHWC); |  | ||||||
| } | } | ||||||
| } | } | ||||||
| } | } | ||||||
| @ -16,16 +16,84 @@ | |||||||
| 
 | 
 | ||||||
| //
 | //
 | ||||||
| // @author raver119@gmail.com
 | // @author raver119@gmail.com
 | ||||||
|  | // @author Yurii Shyrma (iuriish@yahoo.com)
 | ||||||
| //
 | //
 | ||||||
| 
 | 
 | ||||||
| #include <ops/declarable/helpers/adjust_hue.h> | #include <ops/declarable/helpers/adjust_hue.h> | ||||||
|  | #include <helpers/ConstantTadHelper.h> | ||||||
| 
 | 
 | ||||||
| namespace nd4j { | namespace nd4j { | ||||||
| namespace ops { | namespace ops { | ||||||
| namespace helpers { | namespace helpers { | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| template <typename T> | template <typename T> | ||||||
|     static void _adjust_hue_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { | static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     const T delta  = deltaScalarArr->e<T>(0); | ||||||
|  |     const int rank = input->rankOf(); | ||||||
|  | 
 | ||||||
|  |     const T* x = input->bufferAsT<T>(); | ||||||
|  |           T* z = output->bufferAsT<T>(); | ||||||
|  | 
 | ||||||
|  |     if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { | ||||||
|  | 
 | ||||||
|  |         PRAGMA_OMP_PARALLEL_FOR_SIMD | ||||||
|  |         for (Nd4jLong i = 0; i < input->lengthOf(); i += 3) { | ||||||
|  | 
 | ||||||
|  |             T h, s, v; | ||||||
|  | 
 | ||||||
|  |             rgbToHsv<T>(x[i], x[i+1], x[i+2], h, s, v); | ||||||
|  | 
 | ||||||
|  |             h += delta * 360; | ||||||
|  |             if(h > 360) | ||||||
|  |                 h -= 360; | ||||||
|  |             else if(h < 0) | ||||||
|  |                 h += 360; | ||||||
|  | 
 | ||||||
|  |             hsvToRgb<T>(h, s, v, z[i], z[i+1], z[i+2]); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  | 
 | ||||||
|  |         auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),  {dimC}); | ||||||
|  |         auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); | ||||||
|  | 
 | ||||||
|  |         const Nd4jLong numOfTads   = packX.numberOfTads(); | ||||||
|  |         const Nd4jLong xDimCstride = input->stridesOf()[dimC]; | ||||||
|  |         const Nd4jLong zDimCstride = output->stridesOf()[dimC]; | ||||||
|  | 
 | ||||||
|  |         PRAGMA_OMP_PARALLEL_FOR_SIMD | ||||||
|  |         for(Nd4jLong i = 0; i < numOfTads; ++i) { | ||||||
|  | 
 | ||||||
|  |             const T* xTad = x + packX.platformOffsets()[i]; | ||||||
|  |                   T* zTad = z + packZ.platformOffsets()[i]; | ||||||
|  | 
 | ||||||
|  |             T h, s, v; | ||||||
|  | 
 | ||||||
|  |             rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); | ||||||
|  | 
 | ||||||
|  |             h += delta * 360; | ||||||
|  |             if(h > 360) | ||||||
|  |                 h -= 360; | ||||||
|  |             else if(h < 0) | ||||||
|  |                 h += 360; | ||||||
|  | 
 | ||||||
|  |             hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), LIBND4J_TYPES); | ||||||
|  | } | ||||||
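In the generic branch of adjustHue_ above each TAD is a single pixel and its three channel values sit at offsets 0, xDimCstride and 2 * xDimCstride along the channel dimension. A rough standalone illustration of that access pattern for a contiguous channels-first (C, H, W) buffer, with hypothetical names and without the ConstantTadHelper machinery:

#include <cstdio>
#include <vector>

int main() {
    // a tiny 3-channel image stored channels-first: shape (C, H, W) = (3, 2, 2)
    const int C = 3, H = 2, W = 2;
    std::vector<float> img(C * H * W);
    for (int i = 0; i < (int) img.size(); ++i) img[i] = (float) i;

    const int chanStride = H * W;   // stride along the channel dimension
    const int numPixels  = H * W;   // one "tad" per pixel, as in the helper above

    for (int p = 0; p < numPixels; ++p) {
        const float* px = img.data() + p;                        // start of this pixel's channels
        const float r = px[0], g = px[chanStride], b = px[2 * chanStride];
        std::printf("pixel %d: r=%.0f g=%.0f b=%.0f\n", p, r, g, b);
    }
    return 0;
}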
|  | 
 | ||||||
|  | /*
 | ||||||
|  | template <typename T> | ||||||
|  | static void adjust_hue_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { | ||||||
|     // we're 100% sure it's 3
 |     // we're 100% sure it's 3
 | ||||||
|     const int numChannels = 3; |     const int numChannels = 3; | ||||||
|     int tuples = array->lengthOf() /  numChannels; |     int tuples = array->lengthOf() /  numChannels; | ||||||
| @ -93,7 +161,7 @@ namespace helpers { | |||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|     void _adjust_hue(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { | void adjust_hue_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { | ||||||
|     auto xType = array->dataType(); |     auto xType = array->dataType(); | ||||||
| 
 | 
 | ||||||
|     float d = delta->e<float>(0); |     float d = delta->e<float>(0); | ||||||
| @ -104,18 +172,20 @@ namespace helpers { | |||||||
|         // FIXME: template selector should be moved out of loop
 |         // FIXME: template selector should be moved out of loop
 | ||||||
|         PRAGMA_OMP_PARALLEL_FOR |         PRAGMA_OMP_PARALLEL_FOR | ||||||
|         for (int e = 0; e < tSize; e++) { |         for (int e = 0; e < tSize; e++) { | ||||||
|                 BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); |             BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|         delete tadsIn; |         delete tadsIn; | ||||||
|         delete tadsOut; |         delete tadsOut; | ||||||
|     } else { |     } else { | ||||||
|             BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); |         BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|     BUILD_SINGLE_TEMPLATE(template void _adjust_hue_single, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES); | BUILD_SINGLE_TEMPLATE(template void adjust_hue_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES); | ||||||
|  | */ | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
| } | } | ||||||
|  | |||||||
| @ -16,15 +16,83 @@ | |||||||
| 
 | 
 | ||||||
| //
 | //
 | ||||||
| // @author raver119@gmail.com
 | // @author raver119@gmail.com
 | ||||||
|  | // @author Yurii Shyrma (iuriish@yahoo.com)
 | ||||||
| //
 | //
 | ||||||
| 
 | 
 | ||||||
| #include <ops/declarable/helpers/adjust_saturation.h> | #include <ops/declarable/helpers/adjust_saturation.h> | ||||||
|  | #include <ops/declarable/helpers/adjust_hue.h> | ||||||
|  | #include <helpers/ConstantTadHelper.h> | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| namespace nd4j    { | namespace nd4j    { | ||||||
| namespace ops     { | namespace ops     { | ||||||
| namespace helpers { | namespace helpers { | ||||||
| 
 | 
 | ||||||
|  | template <typename T> | ||||||
|  | static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     const T factor = factorScalarArr->e<T>(0); | ||||||
|  |     const int rank = input->rankOf(); | ||||||
|  | 
 | ||||||
|  |     const T* x = input->bufferAsT<T>(); | ||||||
|  |           T* z = output->bufferAsT<T>(); | ||||||
|  | 
 | ||||||
|  |     if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') { | ||||||
|  | 
 | ||||||
|  |         PRAGMA_OMP_PARALLEL_FOR_SIMD | ||||||
|  |         for (Nd4jLong i = 0; i < input->lengthOf(); i += 3) { | ||||||
|  | 
 | ||||||
|  |             T h, s, v; | ||||||
|  | 
 | ||||||
|  |             rgbToHsv<T>(x[i], x[i+1], x[i+2], h, s, v); | ||||||
|  | 
 | ||||||
|  |             s *= factor; | ||||||
|  |             if(s > 1.f) | ||||||
|  |                 s = 1.f; | ||||||
|  |             else if(s < 0.f) | ||||||
|  |                 s = 0.f; | ||||||
|  | 
 | ||||||
|  |             hsvToRgb<T>(h, s, v, z[i], z[i+1], z[i+2]); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  | 
 | ||||||
|  |         auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),  {dimC}); | ||||||
|  |         auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); | ||||||
|  | 
 | ||||||
|  |         const Nd4jLong numOfTads   = packX.numberOfTads(); | ||||||
|  |         const Nd4jLong xDimCstride = input->stridesOf()[dimC]; | ||||||
|  |         const Nd4jLong zDimCstride = output->stridesOf()[dimC]; | ||||||
|  | 
 | ||||||
|  |         PRAGMA_OMP_PARALLEL_FOR_SIMD | ||||||
|  |         for(Nd4jLong i = 0; i < numOfTads; ++i) { | ||||||
|  | 
 | ||||||
|  |             const T* xTad = x + packX.platformOffsets()[i]; | ||||||
|  |                   T* zTad = z + packZ.platformOffsets()[i]; | ||||||
|  | 
 | ||||||
|  |             T h, s, v; | ||||||
|  | 
 | ||||||
|  |             rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); | ||||||
|  | 
 | ||||||
|  |             s *= factor; | ||||||
|  |             if(s > 1.f) | ||||||
|  |                 s = 1.f; | ||||||
|  |             else if(s < 0.f) | ||||||
|  |                 s = 0.f; | ||||||
|  | 
 | ||||||
|  |             hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), LIBND4J_TYPES); | ||||||
|  | } | ||||||
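adjustSaturation_ mirrors adjustHue_ and reuses the rgbToHsv/hsvToRgb converters (hence the extra adjust_hue.h include in this file); the only per-pixel difference is that the scaled saturation is clamped back into [0, 1] instead of wrapping the hue. A one-line sketch of that step, assuming plain floats (hypothetical helper name):

#include <algorithm>

// scaled saturation is clamped into [0, 1], as the kernel above does with its two if-branches
inline float scaleSaturation(float s, float factor) {
    return std::min(1.f, std::max(0.f, s * factor));
}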
|  | 
 | ||||||
|  | /*
 | ||||||
| template <typename T> | template <typename T> | ||||||
| static void adjust_saturation_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { | static void adjust_saturation_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { | ||||||
|     // we're 100% sure it's 3
 |     // we're 100% sure it's 3
 | ||||||
| @ -108,6 +176,7 @@ namespace helpers { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES); | BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES); | ||||||
|  | */ | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
| } | } | ||||||
|  | |||||||
| @ -59,14 +59,17 @@ namespace helpers { | |||||||
|             std::vector<Nd4jLong> dims(reduceShape->lengthOf()); |             std::vector<Nd4jLong> dims(reduceShape->lengthOf()); | ||||||
| 
 | 
 | ||||||
|             bool fit = true; |             bool fit = true; | ||||||
| 
|             PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(&&:fit)) | ||||||
|             for( int i = 0; fit && (i < dims.size()); i++ ) { |             for( int i = 0; i < dims.size(); i++ ) { | ||||||
|  |                 if (fit) { | ||||||
|                     dims[i] = reduceShape->e<Nd4jLong>(i); |                     dims[i] = reduceShape->e<Nd4jLong>(i); | ||||||
|                 for (int e = 0; fit && (e < input->rankOf()); ++e) |                     for (int e = 0; e < input->rankOf(); ++e) | ||||||
|  |                         if (fit) | ||||||
|                         if (input->sizeAt(e) % dims[i]) { |                         if (input->sizeAt(e) % dims[i]) { | ||||||
|                             fit = false; |                             fit = false; | ||||||
|                         } |                         } | ||||||
|                 } |                 } | ||||||
|  |             } | ||||||
| 
 | 
 | ||||||
|             // check dims to fit input
 |             // check dims to fit input
 | ||||||
|             REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); |             REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); | ||||||
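On the fit flag above: with an OpenMP && reduction every thread folds the dimensions it checks into a private copy, and the copies are AND-ed back into the shared flag when the loop ends, so a single non-fitting dimension still trips the REQUIRE_TRUE check. A minimal standalone sketch of that behaviour (plain OpenMP pragma and hypothetical data, not the nd4j pragma macros):

#include <cstdio>

int main() {
    bool fit = true;
    const int sizes[] = {4, 6, 8, 9};

    // each thread gets its own 'fit', initialised to true; results are combined with && at the end
    #pragma omp parallel for reduction(&&:fit)
    for (int i = 0; i < 4; ++i)
        fit = fit && (sizes[i] % 2 == 0);   // 9 is odd, so one private copy becomes false

    std::printf("fit = %d\n", (int) fit);   // prints 0 once the per-thread results are folded back
    return 0;
}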
|  | |||||||
| @ -35,82 +35,88 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, | void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, | ||||||
|              const NDArray* bru, const NDArray* bc, |              const NDArray* b, const NDArray* bc, | ||||||
|              NDArray* r, NDArray* u, NDArray* c, NDArray* h) { |              NDArray* r, NDArray* u, NDArray* c, NDArray* h) { | ||||||
| 
 | 
 | ||||||
|     //Inputs:
 |     //Inputs:
 | ||||||
|     // x        input [bS, nIn], nIn - input size
 |     // x        input [bS, iS], iS - input size
 | ||||||
|     // hLast    previous cell output [bS, nUn],  that is at previous time step t-1, nUn - number of units
 |     // hLast    previous cell output [bS, nU],  that is at previous time step t-1, nU - number of units
 | ||||||
|     // Wru      RU weights - [nIn+nUn, 2*nUn] - reset and update gates
 |     // W        RU weights - [iS+nU, 2*nU] - reset and update gates
 | ||||||
|     // Wc       C weights - [nIn+nUn, nUn] - cell gate
 |     // Wc       C weights - [iS+nU, nU] - cell gate
 | ||||||
|     // bru      r and u biases, [2*nUn] - reset and update gates
 |     // b        r and u biases, [2*nU] - reset and update gates
 | ||||||
|     // bc       c biases, [nUn] - cell gate
 |     // bc       c biases, [nU] - cell gate
 | ||||||
| 
 | 
 | ||||||
|     //Outputs:
 |     //Outputs:
 | ||||||
|     // r        Reset gate output [bS, nUn]
 |     // r        Reset gate output [bS, nU]
 | ||||||
|     // u        Update gate output [bS, nUn]
 |     // u        Update gate output [bS, nU]
 | ||||||
|     // c        Cell gate output [bS, nUn]
 |     // c        Cell gate output [bS, nU]
 | ||||||
|     // h        current cell output [bS, nUn]
 |     // h        current cell output [bS, nU]
 | ||||||
| 
 | 
 | ||||||
|     /***************************************************************************************/ |     /***************************************************************************************/ | ||||||
|     /************************ THIS IS NOT OPTIMIZED CODE ***********************************/ |     /************************ THIS IS NOT OPTIMIZED CODE ***********************************/ | ||||||
|     /** however it is more math-friendly and convenient for backprop formula derivation **/ |     /** however it is more math-friendly and convenient for backprop formula derivation **/ | ||||||
| 
 | 
 | ||||||
|     const int bS  = x->sizeAt(0); |     const int bS  = x->sizeAt(0); | ||||||
|     const int nIn = x->sizeAt(1); |     const int iS = x->sizeAt(1); | ||||||
|     const int nUn = hLast->sizeAt(1); |     const int nU = hLast->sizeAt(1); | ||||||
| 
 | 
 | ||||||
|     NDArray Wr = (*Wru)({0,nIn,       0,0});       // reset gates weights   [nIn, 2*nUn]
 |     NDArray Wrx = (*W)({0,iS,     0,nU});       // [iS, nU]
 | ||||||
|     NDArray Wu = (*Wru)({nIn,nIn+nUn, 0,0});       // updates gates weights [nUn, 2*nUn]
 |     NDArray Wux = (*W)({0,iS,     nU,2*nU});    // [iS, nU]
 | ||||||
|  |     NDArray Wrh = (*W)({iS,iS+nU, 0,nU});       // [nU, nU]
 | ||||||
|  |     NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU});    // [nU, nU]
 | ||||||
| 
 | 
 | ||||||
|     NDArray Wcr = (*Wc)({0,nIn,       0,0});       // reset cell weights    [nIn, nUn]
 |     NDArray Wcx = (*Wc)({0,iS,     0,0});       // reset cell weights    [iS, nU]
 | ||||||
|     NDArray Wcu = (*Wc)({nIn,nIn+nUn, 0,0});       // updates cell weights  [nUn, nUn]
 |     NDArray Wch = (*Wc)({iS,iS+nU, 0,0});       // updates cell weights  [nU, nU]
 | ||||||
| 
 | 
 | ||||||
|     // gates = sigmoid(x*Wr + hLast*Wu + br + bu)
 |     NDArray br = (*b)({0,  nU});                // [nU]
 | ||||||
|     NDArray gates = mmul(*x, Wr) + mmul(*hLast, Wu) + *bru;    // [bS, nIn] * [nIn, 2*nUn] + [bS, nUn] * [nUn, 2*nUn] + [2*nUn] = [bS, 2*nUn]
 |     NDArray bu = (*b)({nU, 2*nU});              // [nU]
 | ||||||
|     gates.applyTransform(transform::Sigmoid); | 
 | ||||||
|  |     // × means matrix multiplication
 | ||||||
|  |     // * means element-wise product or so called Hadamard product
 | ||||||
| 
 | 
 | ||||||
|     // reset gate
 |     // reset gate
 | ||||||
|     r->assign(gates({0,0, 0,nUn}));               // [bS, nUn]
 |     r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br);         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
 | ||||||
|  |     r->applyTransform(transform::Sigmoid); | ||||||
| 
 | 
 | ||||||
|     // update gate
 |     // update gate
 | ||||||
|     u->assign(gates({0,0, nUn,2*nUn}));            // [bS, nUn]
 |     u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu);         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
 | ||||||
|  |     u->applyTransform(transform::Sigmoid); | ||||||
| 
 | 
 | ||||||
|     // cell gate c = activation(x*Wcr + (r◦hlast)*Wcu + bc)
|     // cell gate c = activation(x × Wcx + (r * hLast) × Wch + bc)
 | ||||||
|     c->assign(mmul(*x, Wcr) + mmul(*r * *hLast, Wcu) + *bc);    // [bS, nIn] * [nIn, nUn] + [bS, nUn] * [nUn, nUn] + [nUn] = [bS, nUn]
 |     c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc);    // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
 | ||||||
|     c->applyTransform(transform::Tanh); |     c->applyTransform(transform::Tanh); | ||||||
| 
 | 
 | ||||||
|  |     NDArray temp = 1.f - *c * *c; | ||||||
|  | 
 | ||||||
|     // cell output
 |     // cell output
 | ||||||
|     h->assign(*u * *hLast + (1.f - *u) * *c); |     h->assign(*u * *hLast + (1.f - *u) * *c); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /***************************************************************************************/ |     /***************************************************************************************/ | ||||||
|     /********************** THIS IS MORE OPTIMIZED CODE (except concat) **********************/ |     /*************** THIS IS MORE OPTIMIZED CODE (should think about concat) ***************/ | ||||||
|     /***************************************************************************************/ |     /***************************************************************************************/ | ||||||
| /*
 | /*
 | ||||||
|     //Concat inputs: x + hLast : [bs, nIn + nUn]
 |     //Concat inputs: x + hLast : [bs, iS + nU]
 | ||||||
|     NDArray xhConcat(x->ordering(), {bS, nIn + nUn}, x->dataType(), context);  // concat([bs, nIn], [bs, nUn]) -> [bs, nIn + nUn]
 |     NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context);  // concat([bs, iS], [bs, nU]) -> [bs, iS + nU]
 | ||||||
|     helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)},  xhConcat, {1}); |     helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)},  xhConcat, {1}); | ||||||
| 
 | 
 | ||||||
|     //mmul for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u)
 |     //mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u)
 | ||||||
|     auto m = mmul(xhConcat, *Wru) + *bru ;    // [bs, nIn+nUn] * [nIn+nUn, 2*nUn] = [bs, 2*nUn]
 |     auto m = mmul(xhConcat, *W) + *b ;    // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU]
 | ||||||
|     // m += *bru;
 |     // m += *bru;
 | ||||||
| 
 | 
 | ||||||
|     sigmoidInplace(m);  //sigmoid(rz) and sigmoid(uz)
 |     m.applyTransform(transform::Sigmoid);  //sigmoid(rz) and sigmoid(uz)
 | ||||||
| 
 | 
 | ||||||
|     r->assign(m({0,0, 0, nUn})); |     r->assign(m({0,0, 0, nU})); | ||||||
|     u->assign(m({0,0, nUn, 2*nUn})); |     u->assign(m({0,0, nU, 2*nU})); | ||||||
| 
 | 
 | ||||||
|     // hLast = hLast * r
 |     // hLast = hLast * r
 | ||||||
|     xhConcat({0,0, nIn, nIn+nUn}) *= *r; |     xhConcat({0,0, iS, iS+nU}) *= *r; | ||||||
| 
 | 
 | ||||||
|     //c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c)
 |     //c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c)
 | ||||||
|     MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0);       //c = 1.0 * xhConcat * Wc + 0.0 * c
 |     MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0);       //c = 1.0 * xhConcat * Wc + 0.0 * c
 | ||||||
|     *c += *bc; |     *c += *bc; | ||||||
|     tanhInplace(*c); |     c->applyTransform(transform::Tanh); | ||||||
| 
 | 
 | ||||||
|     //Output: h = (1-u).*c + u .* hPrev
 |     //Output: h = (1-u).*c + u .* hPrev
 | ||||||
|     //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
 |     //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
 | ||||||
| @ -122,19 +128,19 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { | void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { | ||||||
| 
 | 
 | ||||||
|     // x   input [time, bS, iS]
 |     // x   input [time, bS, iS]
 | ||||||
|     // h0  initial cell output (at time step = 0) [bS, nUn]
 |     // hLast  initial cell output (at time step = 0) [bS, nU]
 | ||||||
|     // Wx  input-to-hidden  weights, [iS, 3*nUn]
 |     // Wx  input-to-hidden  weights, [iS, 3*nU]
 | ||||||
|     // Wh  hidden-to-hidden weights, [nUn, 3*nUn]
 |     // Wh  hidden-to-hidden weights, [nU, 3*nU]
 | ||||||
|     // b   biases, [3*nUn]
 |     // b   biases, [3*nU]
 | ||||||
| 
 | 
 | ||||||
|     // h is cell outputs at each time step [time, bS, nUn]
 |     // h is cell outputs at each time step [time, bS, nU]
 | ||||||
| 
 | 
 | ||||||
|     const int time = x->sizeAt(0); |     const int time = x->sizeAt(0); | ||||||
| 
 | 
 | ||||||
|     NDArray ht_1(*h0); |     NDArray ht_1(*hLast); | ||||||
| 
 | 
 | ||||||
|     // loop through time steps
 |     // loop through time steps
 | ||||||
|     for (int t = 0; t < time; ++t) { |     for (int t = 0; t < time; ++t) { | ||||||
| @ -148,105 +154,208 @@ void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0, | void gruCellBP(nd4j::LaunchContext* context, | ||||||
|            const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { |               const NDArray* x,    const NDArray* hLast, | ||||||
|  |               const NDArray* W,    const NDArray* Wc,        const NDArray* b,    const NDArray* bc, | ||||||
|  |               const NDArray* dLdr, const NDArray* dLdu,      const NDArray* dLdc, const NDArray* dLdh, | ||||||
|  |                     NDArray* dLdx,       NDArray* dLdhLast, | ||||||
|  |                     NDArray* dLdW,       NDArray* dLdWc, | ||||||
|  |                     NDArray* dLdb,       NDArray* dLdbc) { | ||||||
| 
 | 
 | ||||||
|  |     //Inputs:
 | ||||||
|     // x              input [bS, iS]
 |     // x              input [bS, iS]
 | ||||||
|     // h0                       previous cell output [bS, nUn],  that is at previous time step t-1
 |     // hLast          previous cell output [bS, nU],  that is at previous time step t-1
 | ||||||
|     // Wx                       input-to-hidden  weights, [iS, 3*nUn]
 |     // W              weights - [iS+nU, 2*nU] - reset and update gates
 | ||||||
|     // Wh                       hidden-to-hidden weights, [nUn, 3*nUn]
 |     // Wc             C weights - [iS+nU, nU] - cell gate
 | ||||||
|     // b                        biases, [3*nUn]
 |     // b              r and u biases, [2*nU] - reset and update gates
 | ||||||
|     // dLdh                     gradient wrt output, [bS,nUn], that is epsilon_next
 |     // bc             c biases, [nU] - cell gate
 | ||||||
|     // dLdWx0                   gradient wrt Wx at previous time step, [iS, 3*nUn]
 |     // dLdr           gradient wrt reset gate, [bS, nU]
 | ||||||
|     // dLdWh0                   gradient wrt Wh at previous time step, [nUn, 3*nUn]
 |     // dLdu           gradient wrt update gate, [bS, nU]
 | ||||||
|     // dLdb0                    gradient wrt b at previous time step,  [3*nUn]
 |     // dLdc           gradient wrt cell state, [bS, nU]
 | ||||||
|  |     // dLdh           gradient wrt current cell output, [bS, nU]
 | ||||||
| 
 | 
 | ||||||
|     // dLdx                   gradient wrt x,  [bS, iS], that is epsilon
 |     //Outputs:
 | ||||||
|     // dLdh0                  gradient wrt h0, [bS, nUn]
 |     // dLdx           gradient wrt x,  [bS, iS],
 | ||||||
|     // dLdWx                  gradient wrt Wx, [iS, 3*nUn]
 |     // dLdhLast       gradient wrt hLast, [bS, nU]
 | ||||||
|     // dLdWh                  gradient wrt Wh, [nUn, 3*nUn]
 |     // dLdW           gradient wrt W,  [iS+nU, 2*nU]
 | ||||||
|     // dLdb                   gradient wrt b at previous time step,  [3*nUn]
 |     // dLdWc          gradient wrt Wc, [iS+nU, nU]
 | ||||||
|  |     // dLdb           gradient wrt bru [2*nU]
 | ||||||
|  |     // dLdbc          gradient wrt bc  [nU]
 | ||||||
| 
 | 
 | ||||||
|     // h is current cell output [bS, nUn], that is at current time step t
 |     // * means element-wise product or so called Hadamard product
 | ||||||
|  |     // × means matrix multiplication
 | ||||||
|  | 
 | ||||||
|  |     /************************************************************************************************/ | ||||||
|  |     /******************************* THIS IS NOT OPTIMIZED CODE *************************************/ | ||||||
|  |     /*** aim is to have math-readable code in order to keep track of the backprop formula derivation ***/ | ||||||
|  | 
 | ||||||
|  |     const int bS  = x->sizeAt(0); | ||||||
|  |     const int iS = x->sizeAt(1); | ||||||
|  |     const int nU = hLast->sizeAt(1); | ||||||
|  | 
 | ||||||
|  |     NDArray xT     = x->transpose();            // [iS, bS]
 | ||||||
|  |     NDArray hLastT = hLast->transpose();        // [nU, bS]
 | ||||||
|  | 
 | ||||||
|  |     NDArray Wrx = (*W)({0,iS,     0,nU});       // [iS, nU]
 | ||||||
|  |     NDArray Wux = (*W)({0,iS,     nU,2*nU});    // [iS, nU]
 | ||||||
|  |     NDArray Wrh = (*W)({iS,iS+nU, 0,nU});       // [nU, nU]
 | ||||||
|  |     NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU});    // [nU, nU]
 | ||||||
|  | 
 | ||||||
|  |     NDArray Wcx = (*Wc)({0,iS,     0,0});       // reset cell weights    [iS, nU]
 | ||||||
|  |     NDArray Wch = (*Wc)({iS,iS+nU, 0,0});       // updates cell weights  [nU, nU]
 | ||||||
|  | 
 | ||||||
|  |     NDArray br = (*b)({0,  nU});                // [nU]
 | ||||||
|  |     NDArray bu = (*b)({nU, 2*nU});              // [nU]
 | ||||||
|  | 
 | ||||||
|  |     NDArray WrxT = Wrx.transpose();             // [nU, iS]
 | ||||||
|  |     NDArray WuxT = Wux.transpose();             // [nU, iS]
 | ||||||
|  |     NDArray WrhT = Wrh.transpose();             // [nU, nU]
 | ||||||
|  |     NDArray WuhT = Wuh.transpose();             // [nU, nU]
 | ||||||
|  | 
 | ||||||
|  |     NDArray WcxT = Wcx.transpose();             // [nU, iS]
 | ||||||
|  |     NDArray WchT = Wch.transpose();             // [nU, nU]
 | ||||||
|  | 
 | ||||||
|  |     NDArray dLdWrx = (*dLdW)({0,iS,     0,nU});     // [iS, nU]
 | ||||||
|  |     NDArray dLdWux = (*dLdW)({0,iS,     nU,2*nU});  // [iS, nU]
 | ||||||
|  |     NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU});     // [nU, nU]
 | ||||||
|  |     NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU});  // [nU, nU]
 | ||||||
|  | 
 | ||||||
|  |     NDArray dLdWcx = (*dLdWc)({0,iS,     0,0});     // [iS, nU]
 | ||||||
|  |     NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0});     // [nU, nU]
 | ||||||
|  | 
 | ||||||
|  |     NDArray dLdbr = (*dLdb)({0,  nU});              // [nU]
 | ||||||
|  |     NDArray dLdbu = (*dLdb)({nU, 2*nU});            // [nU]
 | ||||||
| 
 | 
 | ||||||
|     const int nUn = h0->sizeAt(1); |  | ||||||
| 
 | 
 | ||||||
|     // ***** feed forward step ***** //
 |     // ***** feed forward step ***** //
 | ||||||
|     // gates = sigmoid(x*Wx + h0*Wh + b)
 | 
 | ||||||
|     auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nUn})) + mmul(*h0, (*Wh)({0,0, 0,2*nUn})) + (*b)({0,2*nUn}));       // [bS, 2*nUn] + [bS, 2*nUn] + [1, 2*nUn] = [bS, 2*nUn]
 |  | ||||||
|     // reset gate
 |     // reset gate
 | ||||||
|     auto r = gates({0,0, 0, nUn});               // [bS, nUn]
 |     NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
 | ||||||
|  |     r.applyTransform(transform::Sigmoid); | ||||||
|  | 
 | ||||||
|     // update gate
 |     // update gate
 | ||||||
|     auto u = gates({0,0, nUn, 2*nUn});            // [bS, nUn]
 |     NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
 | ||||||
|     // ◦ means element-wise product or so called Hadamard product
 |     u.applyTransform(transform::Sigmoid); | ||||||
|     // n = tanh(x*Wx + (r◦h0)*Wh + b)
 | 
 | ||||||
|     auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nUn,3*nUn})) + mmul((*h0)*r, (*Wh)({0,0, 2*nUn,3*nUn})) + (*b)({2*nUn,3*nUn}));     // [bS, nUn]
|     // cell gate c = activation(x×Wcx + (r*hLast)×Wch + bc)
 | ||||||
|  |     NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc;    // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
 | ||||||
|  |     c.applyTransform(transform::Tanh); | ||||||
|  | 
 | ||||||
|  |     // h = (1 - u) * c + u * hPrev
 | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
|     // ***** back prop step ***** //
 |     // ***** back prop step ***** //
 | ||||||
|     auto Wxr  = (*Wx)({0,0, 0,   nUn}); |  | ||||||
|     auto Wxu  = (*Wx)({0,0, nUn,  2*nUn}); |  | ||||||
|     auto Wxn  = (*Wx)({0,0, 2*nUn,3*nUn}); |  | ||||||
|     auto Whr  = (*Wh)({0,0, 0,   nUn}); |  | ||||||
|     auto Whu  = (*Wh)({0,0, nUn,  2*nUn}); |  | ||||||
|     auto Whn  = (*Wh)({0,0, 2*nUn,3*nUn}); |  | ||||||
|     auto WxrT = Wxr.transpose(); |  | ||||||
|     auto WxuT = Wxu.transpose(); |  | ||||||
|     auto WxnT = Wxn.transpose(); |  | ||||||
|     auto WhrT = Whr.transpose(); |  | ||||||
|     auto WhuT = Whu.transpose(); |  | ||||||
|     auto WhnT = Whn.transpose(); |  | ||||||
|     auto xT   = x->transpose(); |  | ||||||
|     auto h0T  = h0->transpose(); |  | ||||||
| 
 | 
 | ||||||
|     auto dLdWxr = (*dLdWx)({0,0, 0,     nUn}); |     // notations:
 | ||||||
|     auto dLdWxu = (*dLdWx)({0,0, nUn,  2*nUn}); |     // Zr = x × Wrx + hLast × Wrh + br
 | ||||||
|     auto dLdWxn = (*dLdWx)({0,0, 2*nUn,3*nUn}); |     // Zu = x × Wux + hLast × Wuh + bu
 | ||||||
|  |     // Sr = sigmoid(Zr)
 | ||||||
|  |     // Su = sigmoid(Zu)
 | ||||||
|  |     // Zc = x × Wcx + (r * hLast) × Wch + bc
 | ||||||
| 
 | 
 | ||||||
|     auto dLdWhr = (*dLdWh)({0,0, 0,     nUn}); |  | ||||||
|     auto dLdWhu = (*dLdWh)({0,0, nUn,  2*nUn}); |  | ||||||
|     auto dLdWhn = (*dLdWh)({0,0, 2*nUn,3*nUn}); |  | ||||||
| 
 | 
 | ||||||
|     auto dLdbr = (*dLdb)({0,     nUn}); |     // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
 | ||||||
|     auto dLdbu = (*dLdb)({nUn,  2*nUn}); |     //      = dLdx_u + dLdx_c
 | ||||||
|     auto dLdbn = (*dLdb)({2*nUn,3*nUn}); |     // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
 | ||||||
|  |     // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
 | ||||||
|  |     // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
 | ||||||
|  |     // dZcdr = (... * hLast) × WchT
 | ||||||
|  |     // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
 | ||||||
|  |     // drdx = drdZr * dZrdx
 | ||||||
|  |     // dZrdx = ... × WrxT
 | ||||||
|  |     // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
 | ||||||
|  |     // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
 | ||||||
| 
 | 
 | ||||||
|     auto dhdu   = *h0  - n;              // [bS, nUn]
 |  | ||||||
|     auto dhdn   = 1.f - u;               // [bS, nUn]
 |  | ||||||
|     auto dSigdu = u * (1.f - u);         // [bS, nUn]
 |  | ||||||
|     auto dSigdr = r * (1.f - r);         // [bS, nUn]
 |  | ||||||
|     auto dActdn = 1.f - n * n;           // [bS, nUn]
 |  | ||||||
|     auto dndr   = mmul(dActdn * (*h0), WhnT); |  | ||||||
|     auto drdh0  = mmul(dSigdr, WhrT); |  | ||||||
| 
 | 
 | ||||||
|     auto dLdn = (*dLdh) * dhdn; |     // dLdhLast    = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
 | ||||||
|     auto dLdu = (*dLdh) * dhdu; |     //             = dLdhLast_h + dLdhLast_u + dLdhLast_c
 | ||||||
|     auto dLdr = dLdn * dndr; |     // dLdhLast_h  = dLdh * dhdhLast = dLdh * u
 | ||||||
|  |     // dLdhLast_u  = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
 | ||||||
|  |     // dLdhLast_c  = dLdc * dcdhLast  = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
 | ||||||
|  |     //             = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
 | ||||||
|  |     //             = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
 | ||||||
|  |     // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
 | ||||||
|  |     // dLdhLast_c1 = dLdr * drdhLast = |drdhLast  = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
 | ||||||
|  |     // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
 | ||||||
|  |     //                  = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
 | ||||||
| 
 | 
 | ||||||
|     dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) );      // [bS,iS]
 |  | ||||||
|     dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u );       // [bS,nUn]
 |  | ||||||
| 
 | 
 | ||||||
|     dLdWxr.assign( mmul(xT, dSigdr * dLdr) );                                                               //  [iS,nUn]
 |     // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
 | ||||||
|     dLdWhr.assign( mmul(h0T, dSigdr * dLdr) );                                                              //  [nUn,nUn]
 |     //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
 | ||||||
|  |     // dZrdWrx = xT × ...
 | ||||||
|  |     // finally dLdWrx = xT × (dLdr * drdZr)
 | ||||||
| 
 | 
 | ||||||
|     dLdWxu.assign( mmul(xT, dSigdu * dLdu) );                                                               //  [iS,nUn]
 |  | ||||||
|     dLdWhu.assign( mmul(h0T, dSigdu * dLdu) );                                                              //  [nUn,nUn]
 |  | ||||||
| 
 | 
 | ||||||
|     dLdWxn.assign( mmul(xT, dActdn * dLdn) );                                                               //  [iS,nUn]
 |     // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
 | ||||||
|     dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) );                                               //  [nUn,nUn]
 |     //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
 | ||||||
|  |     // dZrdWrh = hLastT × ...
 | ||||||
|  |     // finally dLdWrh = hLastT × (dLdr * drdZr)
 | ||||||
| 
 | 
 | ||||||
|     dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0}));                          // [nUn]
 |  | ||||||
|     dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0}));                          // [nUn]
 |  | ||||||
|     dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0}));                          // [nUn]
 |  | ||||||
| 
 | 
 | ||||||
|     if(dLdWx0 != nullptr) |     // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
 | ||||||
|         *dLdWx += *dLdWx0; |     // dZudWux = xT × ...
 | ||||||
|  |     // finally dLdWux = xT × (dLdu * dudZu)
 | ||||||
| 
 | 
 | ||||||
|     if(dLdWh0 != nullptr) |  | ||||||
|         *dLdWh += *dLdWh0; |  | ||||||
| 
 | 
 | ||||||
|     if(dLdb0 != nullptr) |     // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
 | ||||||
|         *dLdb += *dLdb0; |     // dZudWuh = hLastT × ...
 | ||||||
|  |     // finally dLdWuh = hLastT × (dLdu * dudZu)
 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
 | ||||||
|  |     // dZcdWcx = xT × ...
 | ||||||
|  |     // finally dLdWcx = xT × (dLdc * dcdZc)
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
 | ||||||
|  |     // dZcdWch = (r*hLast)^T × ...
 | ||||||
|  |     // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
 | ||||||
|  |     //       = dLdr * drdZr * dZrdbr
 | ||||||
|  |     // dZrdbr = 1
 | ||||||
|  |     // finally dLdbr = dLdr * drdZr
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
 | ||||||
|  |     // dZudbu = 1
 | ||||||
|  |     // finally dLdbu = dLdu * dudZu
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
 | ||||||
|  |     // dZcdbc = 1
 | ||||||
|  |     // finally dLdbc = dLdc * dcdZc
 | ||||||
|  | 
 | ||||||
|  |     NDArray dhdc  = 1.f - u;           // [bS, nU]
 | ||||||
|  |     NDArray dhdu  = *hLast - c;        // [bS, nU]
 | ||||||
|  |     NDArray dudZu = u * dhdc;          // [bS, nU]
 | ||||||
|  |     NDArray drdZr = r * (1.f - r);     // [bS, nU]
 | ||||||
|  |     NDArray dcdZc = 1.f - c * c;       // [bS, nU]
 | ||||||
|  |     NDArray dLdZc = *dLdc * dcdZc;     // [bS, nU]
 | ||||||
|  |     NDArray dLdZu = *dLdu * dudZu;     // [bS, nU]
 | ||||||
|  |     NDArray dLdZr = *dLdr * drdZr;     // [bS, nU]
 | ||||||
|  | 
 | ||||||
|  |     // NDArray dLdc  = *dLdh * dhdc;                       // [bS, nU]
 | ||||||
|  |     // NDArray dLdu  = *dLdh * dhdu;                       // [bS, nU]
 | ||||||
|  |     // NDArray dLdr  = mmul(dLdc * dcdZc * *hLast, WchT);  // [bS, nU]
 | ||||||
|  | 
 | ||||||
|  |     dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT));                        // [bS, iS]
 | ||||||
|  | 
 | ||||||
|  |     dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT));    // [bS, nU]
 | ||||||
|  | 
 | ||||||
|  |     dLdWrx.assign(mmul(xT,     dLdZr));     // [iS, bS] × [bS, nU] = [iS, nU]
 | ||||||
|  |     dLdWrh.assign(mmul(hLastT, dLdZr));     // [nU, bS] × [bS, nU] = [nU, nU]
 | ||||||
|  |     dLdWux.assign(mmul(xT,     dLdZu));     // [iS, bS] × [bS, nU] = [iS, nU]
 | ||||||
|  |     dLdWuh.assign(mmul(hLastT, dLdZu));     // [nU, bS] × [bS, nU] = [nU, nU]
 | ||||||
|  | 
 | ||||||
|  |     dLdWcx.assign(mmul(xT, dLdZc));                          // [iS, bS] × [bS, nU] = [iS, nU]
 | ||||||
|  |     dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc));    // [nU, bS] × [bS, nU] = [nU, nU]
 | ||||||
|  | 
 | ||||||
|  |     dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0}));  // [nU]
 | ||||||
|  |     dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0}));  // [nU]
 | ||||||
|  | 
 | ||||||
|  |     dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU]
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // //////////////////////////////////////////////////////////////////////////
 | // //////////////////////////////////////////////////////////////////////////
 | ||||||
| @ -255,34 +364,34 @@ void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h | |||||||
| // void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
 | // void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
 | ||||||
| 
 | 
 | ||||||
| //     NDArray<T>* x      = inArrs[0];                   // input [time, bS, iS]
 | //     NDArray<T>* x      = inArrs[0];                   // input [time, bS, iS]
 | ||||||
| //     NDArray<T>* hi     = inArrs[1];                   // previous/initial cell output [bS, nUn],  that is at previous time step t-1
 | //     NDArray<T>* hi     = inArrs[1];                   // previous/initial cell output [bS, nU],  that is at previous time step t-1
 | ||||||
| //     NDArray<T>* Wx     = inArrs[2];                   // input-to-hidden  weights, [iS, 3*nUn]
 | //     NDArray<T>* Wx     = inArrs[2];                   // input-to-hidden  weights, [iS, 3*nU]
 | ||||||
| //     NDArray<T>* Wh     = inArrs[3];                   // hidden-to-hidden weights, [nUn, 3*nUn]
 | //     NDArray<T>* Wh     = inArrs[3];                   // hidden-to-hidden weights, [nU, 3*nU]
 | ||||||
| //     NDArray<T>* b      = inArrs[4];                   // biases, [3*nUn]
 | //     NDArray<T>* b      = inArrs[4];                   // biases, [3*nU]
 | ||||||
| //     NDArray<T>* dLdh   = inArrs[5];                   // gradient wrt output, [time, bS, nUn], that is epsilon_next
 | //     NDArray<T>* dLdh   = inArrs[5];                   // gradient wrt output, [time, bS, nU], that is epsilon_next
 | ||||||
| 
 | 
 | ||||||
| //     NDArray<T>* dLdx   = outArrs[0];                  // gradient wrt x,  [time, bS, iS], that is epsilon
 | //     NDArray<T>* dLdx   = outArrs[0];                  // gradient wrt x,  [time, bS, iS], that is epsilon
 | ||||||
| //     NDArray<T>* dLdhi  = outArrs[1];                  // gradient wrt hi, [bS, nUn]
 | //     NDArray<T>* dLdhi  = outArrs[1];                  // gradient wrt hi, [bS, nU]
 | ||||||
| //     NDArray<T>* dLdWx  = outArrs[2];                  // gradient wrt Wx, [iS, 3*nUn]
 | //     NDArray<T>* dLdWx  = outArrs[2];                  // gradient wrt Wx, [iS, 3*nU]
 | ||||||
| //     NDArray<T>* dLdWh  = outArrs[3];                  // gradient wrt Wh, [nUn, 3*nUn]
 | //     NDArray<T>* dLdWh  = outArrs[3];                  // gradient wrt Wh, [nU, 3*nU]
 | ||||||
| //     NDArray<T>* dLdb   = outArrs[4];                  // gradient wrt b,  [3*nUn]
 | //     NDArray<T>* dLdb   = outArrs[4];                  // gradient wrt b,  [3*nU]
 | ||||||
| 
 | 
 | ||||||
| //     const Nd4jLong time = x->sizeAt(0);
 | //     const Nd4jLong time = x->sizeAt(0);
 | ||||||
| //     const Nd4jLong bS   = x->sizeAt(1);
 | //     const Nd4jLong bS   = x->sizeAt(1);
 | ||||||
| //     const Nd4jLong iS   = x->sizeAt(2);
 | //     const Nd4jLong iS   = x->sizeAt(2);
 | ||||||
| //     const Nd4jLong nUn   = hi->sizeAt(1);
 | //     const Nd4jLong nU   = hi->sizeAt(1);
 | ||||||
| 
 | 
 | ||||||
| //     NDArray<T> h(hi->ordering(), {time, bS, nUn});      // feed forward output
 | //     NDArray<T> h(hi->ordering(), {time, bS, nU});      // feed forward output
 | ||||||
| 
 | 
 | ||||||
| //     // first step, time = 0, feed forward
 | //     // first step, time = 0, feed forward
 | ||||||
| //     NDArray<T> x0 = (*x)({{0,1}, {}, {}});
 | //     NDArray<T> x0 = (*x)({{0,1}, {}, {}});
 | ||||||
| //     NDArray<T> h0 = h({{0,1}, {}, {}});
 | //     NDArray<T> hLast = h({{0,1}, {}, {}});
 | ||||||
| //     helpers::gruCell<T>({&x0, hi, Wx, Wh, b}, &h0);
 | //     helpers::gruCell<T>({&x0, hi, Wx, Wh, b}, &hLast);
 | ||||||
| 
 | 
 | ||||||
| //     // first step, time = 0, back prop
 | //     // first step, time = 0, back prop
 | ||||||
| //     NDArray<T> dLdx0 = (*dLdx)({{0,1}, {}, {}});
 | //     NDArray<T> dLdx0 = (*dLdx)({{0,1}, {}, {}});
 | ||||||
| //     NDArray<T> dLdh0 = (*dLdh)({{0,1}, {}, {}});
 | //     NDArray<T> dLdhLast = (*dLdh)({{0,1}, {}, {}});
 | ||||||
| //     helpers::gruCellBP<T>({&x0, hi, Wx, Wh, b, &dLdh0, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb});
 | //     helpers::gruCellBP<T>({&x0, hi, Wx, Wh, b, &dLdhLast, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb});
 | ||||||
| 
 | 
 | ||||||
| //     // loop through the rest time steps
 | //     // loop through the rest time steps
 | ||||||
| //     for (Nd4jLong t = time-1; t > 0; --t) {
 | //     for (Nd4jLong t = time-1; t > 0; --t) {
 | ||||||
| @ -310,4 +419,3 @@ void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h | |||||||
| } | } | ||||||
| } | } | ||||||
| } | } | ||||||
| 
 |  | ||||||
|  | |||||||
| @ -20,6 +20,8 @@ | |||||||
| 
 | 
 | ||||||
| #include <ops/declarable/helpers/image_suppression.h> | #include <ops/declarable/helpers/image_suppression.h> | ||||||
| //#include <blas/NDArray.h>
 | //#include <blas/NDArray.h>
 | ||||||
|  | #include <algorithm> | ||||||
|  | #include <numeric> | ||||||
| 
 | 
 | ||||||
| namespace nd4j { | namespace nd4j { | ||||||
| namespace ops { | namespace ops { | ||||||
| @ -28,9 +30,8 @@ namespace helpers { | |||||||
|     template <typename T> |     template <typename T> | ||||||
|     static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { |     static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { | ||||||
|         std::vector<Nd4jLong> indices(scales->lengthOf()); |         std::vector<Nd4jLong> indices(scales->lengthOf()); | ||||||
|  |         std::iota(indices.begin(), indices.end(), 0); | ||||||
| 
 | 
 | ||||||
|         for (size_t i = 0; i < indices.size(); ++i) |  | ||||||
|             indices[i] = i; |  | ||||||
|         std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);}); |         std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);}); | ||||||
| 
 | 
 | ||||||
| //        std::vector<int> selected(output->lengthOf());
 | //        std::vector<int> selected(output->lengthOf());
 | ||||||
| @ -62,13 +63,15 @@ namespace helpers { | |||||||
|         }; |         }; | ||||||
| //        int numSelected = 0;
 | //        int numSelected = 0;
 | ||||||
|         int numBoxes = boxes->sizeAt(0); |         int numBoxes = boxes->sizeAt(0); | ||||||
|  |         int numSelected = 0; | ||||||
| 
 | 
 | ||||||
|         for (int i = 0, numSelected = 0; i < numBoxes && numSelected < output->lengthOf(); ++i) { |         for (int i = 0; i < numBoxes; ++i) { | ||||||
|             bool shouldSelect = true; |             bool shouldSelect = numSelected < output->lengthOf(); | ||||||
|  |             PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(numSelected))
 | ||||||
|             for (int j = numSelected - 1; j >= 0; --j) { |             for (int j = numSelected - 1; j >= 0; --j) { | ||||||
|  |                 if (shouldSelect) | ||||||
|                 if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold))) { |                 if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold))) { | ||||||
|                     shouldSelect = false; |                     shouldSelect = false; | ||||||
|                     break; |  | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             if (shouldSelect) { |             if (shouldSelect) { | ||||||
|  | |||||||
| @ -24,20 +24,20 @@ namespace nd4j { | |||||||
| namespace ops { | namespace ops { | ||||||
| namespace helpers { | namespace helpers { | ||||||
| 
 | 
 | ||||||
|     template <typename T> |     template <typename I, typename B> | ||||||
|     static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) { |     static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) { | ||||||
|         PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2) |         PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2) | ||||||
|         for (Nd4jLong i = 0; i < maxIndex; i++) |         for (Nd4jLong i = 0; i < maxIndex; i++) | ||||||
|             for(Nd4jLong k = 0; k < input->lengthOf(); k++) |             for(Nd4jLong k = 0; k < input->lengthOf(); k++) | ||||||
|                 if (i < input->e<int>(k)) |                 if (i < input->t<I>(k)) | ||||||
|                     output->p<T>(k * maxIndex + i,  T(1.0f)); |                     output->t<B>(k * maxIndex + i) = B(true);
 | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { |     void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { | ||||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), sequenceMask_, (input, output, maxIndex), LIBND4J_TYPES); |         BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     BUILD_SINGLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), LIBND4J_TYPES); |     BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES); | ||||||
| } | } | ||||||
| } | } | ||||||
| } | } | ||||||
| @ -27,6 +27,7 @@ | |||||||
| #include <helpers/TAD.h> | #include <helpers/TAD.h> | ||||||
| #include <helpers/ConstantTadHelper.h> | #include <helpers/ConstantTadHelper.h> | ||||||
| #include <Loops.h> | #include <Loops.h> | ||||||
|  | #include <graph/RandomGenerator.h> | ||||||
| 
 | 
 | ||||||
| namespace nd4j 	  { | namespace nd4j 	  { | ||||||
| namespace ops 	  { | namespace ops 	  { | ||||||
| @ -81,7 +82,7 @@ static void trace_(const NDArray& input, NDArray& output) { | |||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| template <typename T> | template <typename T> | ||||||
| void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) { | void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) { | ||||||
| 
 | 
 | ||||||
|     // check edge cases first
 |     // check edge cases first
 | ||||||
|     int temp; |     int temp; | ||||||
| @ -95,16 +96,16 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer& | |||||||
| 
 | 
 | ||||||
|         // apply Fisher-Yates shuffle
 |         // apply Fisher-Yates shuffle
 | ||||||
|         if(isInplace) { |         if(isInplace) { | ||||||
|             PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) |             //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
 | ||||||
|             for(int i = firstDim-1; i > 0; --i) { |             for(int i = firstDim-1; i > 0; --i) { | ||||||
|                 int r = rng.nextInt(0, i); |                 int r = rng.relativeInt(i) % i; | ||||||
|                 if(i == r) |                 if(i == r) | ||||||
|                     continue; |                     continue; | ||||||
|                 T _e0 = input.e<T>(i); |                 T t0 = input.t<T>(i); | ||||||
|                 T _e1 = input.e<T>(r); |                 T t1 = input.t<T>(r); | ||||||
|                 //math::nd4j_swap<T>(input(i), input(r));
 |                 //math::nd4j_swap<T>(input(i), input(r));
 | ||||||
|                 input.p<T>(i, _e1); |                 input.t<T>(i) = t1; | ||||||
|                 input.p<T>(r, _e0); |                 input.t<T>(r) = t0; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         else { |         else { | ||||||
| @ -113,12 +114,12 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer& | |||||||
|             output.p<T>(Nd4jLong(0), input.e<T>(0)); |             output.p<T>(Nd4jLong(0), input.e<T>(0)); | ||||||
|             PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) |             PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) | ||||||
|             for(int i = firstDim-1; i > 0; --i) { |             for(int i = firstDim-1; i > 0; --i) { | ||||||
|                 int r = rng.nextInt(0, i); |                 int r = rng.relativeInt(i) % i; | ||||||
|                 output.p(i, input.e<T>(indices[r])); |                 output.t<T>(i) = input.t<T>(indices[r]); | ||||||
|                 if(i == r) |                 if(i == r) | ||||||
|                     continue; |                     continue; | ||||||
| 
 | 
 | ||||||
|                 output.p(r, input.e<T>(indices[i])); |                 output.t<T>(r) = input.t<T>(indices[i]); | ||||||
|                 math::nd4j_swap<int>(indices[i], indices[r]); |                 math::nd4j_swap<int>(indices[i], indices[r]); | ||||||
|             } |             } | ||||||
|             rng.rewindH(firstDim-1); |             rng.rewindH(firstDim-1); | ||||||
| @ -132,9 +133,10 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer& | |||||||
| 
 | 
 | ||||||
|         // apply Fisher-Yates shuffle
 |         // apply Fisher-Yates shuffle
 | ||||||
|         if(isInplace) { |         if(isInplace) { | ||||||
|             PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold()) |             //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
 | ||||||
|             for(int i = firstDim - 1; i > 0; --i) { |             for(int i = firstDim - 1; i > 0; --i) { | ||||||
|                 int r = rng.nextInt(0, i); |                 int r = rng.relativeInt(i) % i; | ||||||
|  | 
 | ||||||
|                 if(i == r) |                 if(i == r) | ||||||
|                     continue; |                     continue; | ||||||
|                 subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r)); |                 subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r)); | ||||||
| @ -146,9 +148,9 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer& | |||||||
|             std::vector<int> indices(firstDim); |             std::vector<int> indices(firstDim); | ||||||
|             std::iota(indices.begin(), indices.end(), 0); |             std::iota(indices.begin(), indices.end(), 0); | ||||||
|             bool isZeroShuffled = false; |             bool isZeroShuffled = false; | ||||||
|             PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) |             //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
 | ||||||
|             for(int i = firstDim - 1; i > 0; --i) { |             for(int i = firstDim - 1; i > 0; --i) { | ||||||
|                 int r = rng.nextInt(0, i); |                 int r = rng.relativeInt(i) % i; | ||||||
|                 subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r])); |                 subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r])); | ||||||
|                 if(r == 0) |                 if(r == 0) | ||||||
|                     isZeroShuffled = true; |                     isZeroShuffled = true; | ||||||
| @ -167,11 +169,11 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer& | |||||||
| 
 | 
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|     void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) { |     void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) { | ||||||
|         BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); |         BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace), LIBND4J_TYPES); |     BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
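For comparison with the refactored helper above, here is a textbook Fisher-Yates shuffle. It uses std::mt19937_64 rather than nd4j::graph::RandomGenerator, and it draws r uniformly from [0, i] (the helper draws relativeInt(i) % i, i.e. from [0, i-1]), so treat it as an illustration of the algorithm rather than a drop-in replacement.

    #include <cstdint>
    #include <random>
    #include <utility>
    #include <vector>

    // In-place Fisher-Yates shuffle: with a uniform generator, every
    // permutation of `a` is equally likely after the loop.
    void fisherYatesShuffle(std::vector<int>& a, uint64_t seed) {
        std::mt19937_64 rng(seed);
        for (int i = static_cast<int>(a.size()) - 1; i > 0; --i) {
            std::uniform_int_distribution<int> pick(0, i);
            std::swap(a[i], a[pick(rng)]);
        }
    }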
| @ -16,15 +16,92 @@ | |||||||
| 
 | 
 | ||||||
| // | // | ||||||
| // @author raver119@gmail.com | // @author raver119@gmail.com | ||||||
|  | // @author Yurii Shyrma (iuriish@yahoo.com) | ||||||
| // | // | ||||||
| 
 | 
 | ||||||
| #include <ops/declarable/helpers/adjust_hue.h> | #include <ops/declarable/helpers/adjust_hue.h> | ||||||
| #include <helpers/ConstantTadHelper.h> | #include <helpers/ConstantTadHelper.h> | ||||||
|  | #include <PointersManager.h> | ||||||
| 
 | 
 | ||||||
| namespace nd4j    { | namespace nd4j    { | ||||||
| namespace ops     { | namespace ops     { | ||||||
| namespace helpers { | namespace helpers { | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | /////////////////////////////////////////////////////////////////// | ||||||
|  | template <typename T> | ||||||
|  | static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, | ||||||
|  |                                         void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, | ||||||
|  |                                         const Nd4jLong numOfTads, const T delta, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     const T* x = reinterpret_cast<const T*>(vx); | ||||||
|  |           T* z = reinterpret_cast<T*>(vz); | ||||||
|  | 
 | ||||||
|  |     __shared__ int rank; | ||||||
|  |     __shared__ Nd4jLong xDimCstride, zDimCstride; | ||||||
|  | 
 | ||||||
|  |     if (threadIdx.x == 0) { | ||||||
|  |         rank = shape::rank(xShapeInfo); | ||||||
|  |         xDimCstride = shape::stride(xShapeInfo)[dimC]; | ||||||
|  |         zDimCstride = shape::stride(zShapeInfo)[dimC]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     __syncthreads(); | ||||||
|  | 
 | ||||||
|  |     const auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  | 
 | ||||||
|  |     for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { | ||||||
|  | 
 | ||||||
|  |         const T* xTad = x + xTadOffsets[i]; | ||||||
|  |               T* zTad = z + zTadOffsets[i]; | ||||||
|  | 
 | ||||||
|  |         T h, s, v; | ||||||
|  | 
 | ||||||
|  |         rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); | ||||||
|  | 
 | ||||||
|  |         h += delta * 360; | ||||||
|  |         if(h > 360) | ||||||
|  |             h -= 360; | ||||||
|  |         else if(h < 0) | ||||||
|  |             h += 360; | ||||||
|  | 
 | ||||||
|  |         hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /////////////////////////////////////////////////////////////////// | ||||||
|  | template<typename T> | ||||||
|  | static _CUDA_H void adjustHueCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, | ||||||
|  |                                           const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, | ||||||
|  |                                                 void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, | ||||||
|  |                                           const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     adjustHueCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, deltaScalarArr->e<T>(0), dimC); | ||||||
|  | } | ||||||
|  | BUILD_SINGLE_TEMPLATE(template void adjustHueCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC), LIBND4J_TYPES); | ||||||
|  | 
 | ||||||
|  | //////////////////////////////////////////////////////////////////////// | ||||||
|  | void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),  {dimC}); | ||||||
|  |     auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); | ||||||
|  | 
 | ||||||
|  |     const Nd4jLong numOfTads = packX.numberOfTads(); | ||||||
|  | 
 | ||||||
|  |     const int threadsPerBlock = MAX_NUM_THREADS / 2; | ||||||
|  |     const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; | ||||||
|  | 
 | ||||||
|  |     PointersManager manager(context, "adjustHue"); | ||||||
|  | 
 | ||||||
|  |     NDArray::prepareSpecialUse({output}, {input, deltaScalarArr}); | ||||||
|  |     BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, deltaScalarArr, dimC), LIBND4J_TYPES); | ||||||
|  |     NDArray::registerSpecialUse({output}, {input, deltaScalarArr}); | ||||||
|  | 
 | ||||||
|  |     manager.synchronize(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | /* | ||||||
| template <typename T> | template <typename T> | ||||||
| static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo,  void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { | static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo,  void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { | ||||||
|     int numChannels = 3; |     int numChannels = 3; | ||||||
| @ -134,11 +211,13 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
|     float d = delta->e<float>(0); |     float d = delta->e<float>(0); | ||||||
|     if (array->rankOf() == 4) { |     if (array->rankOf() == 4) { | ||||||
|  |         BUILD_SINGLE_SELECTOR(xType, _adjust_hue_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES); | ||||||
|     } else { |     } else { | ||||||
|         BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); |         BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | */ | ||||||
| } | } | ||||||
| } | } | ||||||
| } | } | ||||||
|  | |||||||
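The kernel above treats delta as a fraction of a full turn: each pixel is converted to HSV, the hue (in degrees) is shifted by delta * 360 and wrapped back into the [0, 360) range, and the result is converted back to RGB. The per-pixel hue update, isolated as a plain function for reference (rgbToHsv/hsvToRgb are assumed to exist as in the helper):

    // Hue shift used per pixel: delta is in turns, h in degrees.
    float shiftHue(float h, float delta) {
        h += delta * 360.f;
        if (h > 360.f)
            h -= 360.f;
        else if (h < 0.f)
            h += 360.f;
        return h;
    }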
| @ -16,16 +16,93 @@ | |||||||
| 
 | 
 | ||||||
| // | // | ||||||
| // @author raver119@gmail.com | // @author raver119@gmail.com | ||||||
|  | // @author Yurii Shyrma (iuriish@yahoo.com) | ||||||
| // | // | ||||||
| 
 | 
 | ||||||
| #include <ops/declarable/helpers/adjust_saturation.h> | #include <ops/declarable/helpers/adjust_saturation.h> | ||||||
|  | #include <ops/declarable/helpers/adjust_hue.h> | ||||||
| #include <helpers/ConstantTadHelper.h> | #include <helpers/ConstantTadHelper.h> | ||||||
|  | #include <PointersManager.h> | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| namespace nd4j    { | namespace nd4j    { | ||||||
| namespace ops     { | namespace ops     { | ||||||
| namespace helpers { | namespace helpers { | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | /////////////////////////////////////////////////////////////////// | ||||||
|  | template <typename T> | ||||||
|  | static void _CUDA_G adjustSaturationCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, | ||||||
|  |                                                void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, | ||||||
|  |                                         const Nd4jLong numOfTads, const T factor, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     const T* x = reinterpret_cast<const T*>(vx); | ||||||
|  |           T* z = reinterpret_cast<T*>(vz); | ||||||
|  | 
 | ||||||
|  |     __shared__ int rank; | ||||||
|  |     __shared__ Nd4jLong xDimCstride, zDimCstride; | ||||||
|  | 
 | ||||||
|  |     if (threadIdx.x == 0) { | ||||||
|  |         rank = shape::rank(xShapeInfo); | ||||||
|  |         xDimCstride = shape::stride(xShapeInfo)[dimC]; | ||||||
|  |         zDimCstride = shape::stride(zShapeInfo)[dimC]; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     __syncthreads(); | ||||||
|  | 
 | ||||||
|  |     const auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  | 
 | ||||||
|  |     for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) { | ||||||
|  | 
 | ||||||
|  |         const T* xTad = x + xTadOffsets[i]; | ||||||
|  |               T* zTad = z + zTadOffsets[i]; | ||||||
|  | 
 | ||||||
|  |         T h, s, v; | ||||||
|  | 
 | ||||||
|  |         rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v); | ||||||
|  | 
 | ||||||
|  |         s *= factor; | ||||||
|  |         if(s > 1.f) | ||||||
|  |             s = 1.f; | ||||||
|  |         else if(s < 0.f) | ||||||
|  |             s = 0.f; | ||||||
|  | 
 | ||||||
|  |         hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /////////////////////////////////////////////////////////////////// | ||||||
|  | template<typename T> | ||||||
|  | static _CUDA_H void adjustSaturationCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, | ||||||
|  |                                           const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, | ||||||
|  |                                                 void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, | ||||||
|  |                                           const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     adjustSaturationCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, factorScalarArr->e<T>(0), dimC); | ||||||
|  | } | ||||||
|  | BUILD_SINGLE_TEMPLATE(template void adjustSaturationCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC), LIBND4J_TYPES); | ||||||
|  | 
 | ||||||
|  | //////////////////////////////////////////////////////////////////////// | ||||||
|  | void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) { | ||||||
|  | 
 | ||||||
|  |     auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(),  {dimC}); | ||||||
|  |     auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC}); | ||||||
|  | 
 | ||||||
|  |     const Nd4jLong numOfTads = packX.numberOfTads(); | ||||||
|  | 
 | ||||||
|  |     const int threadsPerBlock = MAX_NUM_THREADS / 2; | ||||||
|  |     const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock; | ||||||
|  | 
 | ||||||
|  |     PointersManager manager(context, "adjustSaturation"); | ||||||
|  | 
 | ||||||
|  |     NDArray::prepareSpecialUse({output}, {input, factorScalarArr}); | ||||||
|  |     BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, factorScalarArr, dimC), LIBND4J_TYPES); | ||||||
|  |     NDArray::registerSpecialUse({output}, {input, factorScalarArr}); | ||||||
|  | 
 | ||||||
|  |     manager.synchronize(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* | ||||||
| template <typename T> | template <typename T> | ||||||
| static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo,  void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { | static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo,  void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { | ||||||
|     int numChannels = 3; |     int numChannels = 3; | ||||||
| @ -129,7 +206,7 @@ namespace helpers { | |||||||
|         BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); |         BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | */ | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
| } | } | ||||||
|  | |||||||
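Analogously to adjust_hue, the saturation kernel scales the S channel by the factor and clamps it to the valid [0, 1] range before converting back to RGB. Isolated per-pixel update for reference:

    // Saturation scaling used per pixel: the result stays within [0, 1].
    float scaleSaturation(float s, float factor) {
        s *= factor;
        if (s > 1.f)
            s = 1.f;
        else if (s < 0.f)
            s = 0.f;
        return s;
    }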
| @ -22,20 +22,99 @@ | |||||||
| #include <NativeOps.h> | #include <NativeOps.h> | ||||||
| #include <vector> | #include <vector> | ||||||
| #include <memory> | #include <memory> | ||||||
|  | #include <cuda_exception.h> | ||||||
| 
 | 
 | ||||||
| namespace nd4j { | namespace nd4j { | ||||||
| namespace ops { | namespace ops { | ||||||
| namespace helpers { | namespace helpers { | ||||||
| 
 | 
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     static void dropoutSimple(NDArray const* input, NDArray* output, double probValue, int seed) { |     static __global__ void dropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong* outputShape, double probVal, int inLen, nd4j::graph::RandomGenerator* nodeRng) { | ||||||
|  |         auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |         auto step = blockDim.x * gridDim.x; | ||||||
|  |         __shared__ T const* input; | ||||||
|  |         __shared__ T* output; | ||||||
| 
 | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             input = reinterpret_cast<T const*>(inputBuf); | ||||||
|  |             output = reinterpret_cast<T*>(outputBuf); | ||||||
|         } |         } | ||||||
|     BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES); | 
 | ||||||
|  |         for (Nd4jLong e = tid; e < inLen; e += step) { // grid-stride loop: each thread handles its own subset of elements | ||||||
|  |             T val = nodeRng->relativeT(e, T(0.f), T(1.f)); | ||||||
|  | 
 | ||||||
|  |             if (double(val) < probVal) | ||||||
|  |                 output[shape::getIndexOffset(e, outputShape, inLen)] = T(input[shape::getIndexOffset(e, inputShape, inLen)] / probVal); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     template <typename T> | ||||||
|  |     static void dropoutSimple(nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed) { | ||||||
|  |         nd4j::graph::RandomGenerator nodeRng(3019L, seed); | ||||||
|  |         int inLen = input->lengthOf(); | ||||||
|  |         nd4j::graph::RandomGenerator* dRandom; | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input}); | ||||||
|  | 
 | ||||||
|  |         auto err = cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator)); | ||||||
|  |         if (err) { | ||||||
|  |             throw cuda_exception::build("helpers::dropoutSimple: Cannot allocate device memory for random generator.", err); | ||||||
|  |         } | ||||||
|  |         err = cudaMemcpy(dRandom, &nodeRng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice); | ||||||
|  |         if (err) { | ||||||
|  |             throw cuda_exception::build("helpers::dropoutSimple: Cannot set up device memory for random generator.", err); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         dropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, inLen, dRandom); | ||||||
|  |         err = cudaFree(dRandom); | ||||||
|  |         if (err) { | ||||||
|  |             throw cuda_exception::build("helpers::dropoutSimple: Cannot deallocate device memory for random generator.", err); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES); | ||||||
| 
 | 
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { |     int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { | ||||||
| 
 | 
 | ||||||
|  |         if (reduceShape == nullptr){ | ||||||
|  |             dropoutSimple<T>(context.launchContext(), input, output, probValue, seed); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape must be compatible with the input shape"); | ||||||
|  | 
 | ||||||
|  |             std::vector<Nd4jLong> dims(reduceShape->lengthOf()); | ||||||
|  |             reduceShape->syncToHost(); // make sure the host buffer is up to date before reading it below | ||||||
|  |             bool fit = true; | ||||||
|  | //            PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit)) | ||||||
|  |             for( int i = 0; i < dims.size(); i++ ) { | ||||||
|  |                 if (fit) { | ||||||
|  |                     dims[i] = reduceShape->e<Nd4jLong>(i); | ||||||
|  |                     for (int e = 0; e < input->rankOf(); ++e) | ||||||
|  |                         if (fit) | ||||||
|  |                             if (input->sizeAt(e) % dims[i]) { | ||||||
|  |                                 fit = false; | ||||||
|  |                             } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // check dims to fit input | ||||||
|  |             REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); | ||||||
|  |             std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext())); | ||||||
|  |             chunk->assign(1.f); | ||||||
|  |             //chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue); | ||||||
|  |             //NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob); | ||||||
|  |             dropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), probValue, seed); | ||||||
|  |             // broadcast chunk to full matrix | ||||||
|  |             std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input)); | ||||||
|  |             dropOutMultiplier->assign(1.f); | ||||||
|  | 
 | ||||||
|  |             *dropOutMultiplier += *chunk; | ||||||
|  | 
 | ||||||
|  |             output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|         return Status::OK(); |         return Status::OK(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -48,14 +127,121 @@ namespace helpers { | |||||||
|     BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES); |     BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES); | ||||||
| 
 | 
 | ||||||
| /////////////////////////////////// backpropagations /////////////////////////////////////////////// | /////////////////////////////////// backpropagations /////////////////////////////////////////////// | ||||||
|  |     template <typename T> | ||||||
|  |     static __global__ void dropoutBPKernel(void* outputBuf, Nd4jLong* outputShape, void* gradOutBuf, Nd4jLong* gradOutShape, double probValue) { | ||||||
|  |         __shared__ T* output; | ||||||
|  |         __shared__ T* input; | ||||||
|  |         __shared__ int len; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             len = shape::length(outputShape); | ||||||
|  |             output = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             input = reinterpret_cast<T*>(gradOutBuf); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |         auto step = blockDim.x * gridDim.x; | ||||||
|  | 
 | ||||||
|  |         for (int e = tid; e < len; e += step) { | ||||||
|  |             if (output[shape::getIndexOffset(e, outputShape, len)] != T(0.)) | ||||||
|  |                 output[shape::getIndexOffset(e, outputShape, len)] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |     } | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { |     static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { | ||||||
|         return Status::OK(); |         int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue); | ||||||
|  |         auto stream = context.launchContext()->getCudaStream(); | ||||||
|  | 
 | ||||||
|  |         if (ND4J_STATUS_OK == res) | ||||||
|  |             dropoutBPKernel<T><<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue); | ||||||
|  | 
 | ||||||
|  |         return res; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     template <typename T> | ||||||
|  |     static __global__ void alphaDropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong* outputShape, double probValue, double alpha, double alpha1, double beta, int inLen, nd4j::graph::RandomGenerator* nodeRng) { | ||||||
|  |         auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |         auto step = blockDim.x * gridDim.x; | ||||||
|  |         __shared__ T const* input; | ||||||
|  |         __shared__ T* output; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             input = reinterpret_cast<T const*>(inputBuf); | ||||||
|  |             output = reinterpret_cast<T*>(outputBuf); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         for (auto e = tid; e < inLen; e += step) { | ||||||
|  |             T val = nodeRng->relativeT(e, T(0.f), T(1.f)); | ||||||
|  |             T xVal = input[shape::getIndexOffset(e, inputShape, inLen)]; | ||||||
|  |             output[shape::getIndexOffset(e, outputShape, inLen)] = (val >= T(probValue) ? T(alpha * beta + alpha1) : T(alpha * (double)xVal + alpha1)); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     template <typename T> | ||||||
|  |     static void alphaDropoutSimple(nd4j::LaunchContext* context, NDArray const* input, NDArray* output, int seed, double probValue, double alpha, double alpha1, double beta) { | ||||||
|  |         nd4j::graph::RandomGenerator nodeRng(3019L, seed), *dRandom; | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         auto err = cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator)); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input}); | ||||||
|  |         if (err) { | ||||||
|  |             throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot allocate device memory for random generator.", err); | ||||||
|  |         } | ||||||
|  |         err = cudaMemcpy(dRandom, &nodeRng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice); | ||||||
|  |         if (err) { | ||||||
|  |             throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot set up device memory for random generator.", err); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         alphaDropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, alpha, alpha1, beta, output->lengthOf(), dRandom); | ||||||
|  | 
 | ||||||
|  |         err = cudaFree(dRandom); | ||||||
|  |         if (err) { | ||||||
|  |             throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot deallocate device memory for random generator.", err); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input}); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, |     static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, | ||||||
|                             NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { |                             NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { | ||||||
|  | 
 | ||||||
|  |         if (reduceShape == nullptr){ | ||||||
|  |             alphaDropoutSimple<T>(context.launchContext(), input, output, seed, probValue, alpha, alpha1, beta); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "alpha_dropout: Noise shape must be compatible with the input shape"); | ||||||
|  | 
 | ||||||
|  |             std::vector<Nd4jLong> dims(reduceShape->lengthOf()); | ||||||
|  |             reduceShape->syncToHost(); // make sure the host buffer is up to date before reading it below | ||||||
|  |             bool fit = true; | ||||||
|  | //            PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit)) | ||||||
|  |             for( int i = 0; i < dims.size(); i++ ) { | ||||||
|  |                 if (fit) { | ||||||
|  |                     dims[i] = reduceShape->e<Nd4jLong>(i); | ||||||
|  |                     for (int e = 0; e < input->rankOf(); ++e) | ||||||
|  |                         if (fit) | ||||||
|  |                             if (input->sizeAt(e) % dims[i]) { | ||||||
|  |                                 fit = false; | ||||||
|  |                             } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // check dims to fit input | ||||||
|  |             REQUIRE_TRUE(fit, 0, "alpha_dropout: Noise shape should fit to input rank."); | ||||||
|  |             std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext())); | ||||||
|  |             chunk->assign(1.f); | ||||||
|  |             //chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue); | ||||||
|  |             //NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob); | ||||||
|  |             alphaDropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), seed, probValue, alpha, alpha1, beta); | ||||||
|  |             // broadcast chunk to full matrix | ||||||
|  |             std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input)); | ||||||
|  |             dropOutMultiplier->assign(1.f); | ||||||
|  | 
 | ||||||
|  |             *dropOutMultiplier += *chunk; | ||||||
|  | 
 | ||||||
|  |             output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|         return Status::OK(); |         return Status::OK(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -63,7 +249,12 @@ namespace helpers { | |||||||
|     int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, |     int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, | ||||||
|                               NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { |                               NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { | ||||||
| 
 | 
 | ||||||
|         return Status::OK(); |         int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta); | ||||||
|  |         if (res == ND4J_STATUS_OK) { | ||||||
|  |             (*output) *= alpha; | ||||||
|  |             (*output) *= (*gradOut); //->applyPairwiseTransform<transform::Multiply>(gradOut, output, nullptr); | ||||||
|  |         } | ||||||
|  |         return res; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { |     int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { | ||||||
|  | |||||||
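dropoutSimpleKernel implements inverted dropout: an element is kept when its per-index random draw falls below probValue and is rescaled by 1/probValue, so the expected activation is unchanged at inference time. A rough CPU sketch of that rule, using std::mt19937_64 in place of nd4j::graph::RandomGenerator (so the drawn numbers differ from the helper's):

    #include <cstdint>
    #include <random>
    #include <vector>

    // Inverted dropout: kept elements are scaled by 1/probValue, dropped ones are zero.
    std::vector<float> invertedDropout(const std::vector<float>& in, double probValue, uint64_t seed) {
        std::mt19937_64 rng(seed);
        std::uniform_real_distribution<double> uniform(0.0, 1.0);
        std::vector<float> out(in.size(), 0.f);
        for (std::size_t e = 0; e < in.size(); ++e)
            if (uniform(rng) < probValue)
                out[e] = static_cast<float>(in[e] / probValue);
        return out;
    }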
| @ -35,58 +35,88 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, | void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, | ||||||
|              const NDArray* bru, const NDArray* bc, |              const NDArray* b, const NDArray* bc, | ||||||
|              NDArray* r, NDArray* u, NDArray* c, NDArray* h) { |              NDArray* r, NDArray* u, NDArray* c, NDArray* h) { | ||||||
| 
 | 
 | ||||||
|     //Inputs: |     //Inputs: | ||||||
|     // x        input [bS x inSize] |     // x        input [bS, iS], iS - input size | ||||||
|     // hLast    previous cell output [bS x numUnits],  that is at previous time step t-1 |     // hLast    previous cell output [bS, nU],  that is at previous time step t-1, nU - number of units | ||||||
|     // Wru      RU weights - [bS, 2*numUnits] - reset and update gates |     // W        RU weights - [iS+nU, 2*nU] - reset and update gates | ||||||
|     // Wc       C weights - [bS, numUnits] - cell gate |     // Wc       C weights - [iS+nU, nU] - cell gate | ||||||
|     // bru      r and u biases, [2*numUnits] - reset and update gates |     // b        r and u biases, [2*nU] - reset and update gates | ||||||
|     // bc       c biases, [numUnits] - cell gate |     // bc       c biases, [nU] - cell gate | ||||||
| 
 | 
 | ||||||
|     //Outputs: |     //Outputs: | ||||||
|     // r        Reset gate output [bS, numUnits] |     // r        Reset gate output [bS, nU] | ||||||
|     // u        Update gate output [bS, numUnits] |     // u        Update gate output [bS, nU] | ||||||
|     // c        Cell gate output [bS, numUnits] |     // c        Cell gate output [bS, nU] | ||||||
|     // h        current cell output [bS, numUnits] |     // h        current cell output [bS, nU] | ||||||
| 
 | 
 | ||||||
|     const int nIn = x->sizeAt(1); |     /***************************************************************************************/ | ||||||
|     const int nU = hLast->sizeAt(1);                // number of units |     /************************ THIS IS NOT OPTIMIZED CODE ***********************************/ | ||||||
|  |     /** however it is more math-friendly and convenient for backprop formula derivation **/ | ||||||
| 
 | 
 | ||||||
|     //Concat inputs: [x, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] |     const int bS  = x->sizeAt(0); | ||||||
|     nd4j::ops::concat concatOp; |     const int iS = x->sizeAt(1); | ||||||
|     std::vector<NDArray*> inputs; |     const int nU = hLast->sizeAt(1); | ||||||
|     std::vector<double> targs; |  | ||||||
|     std::vector<Nd4jLong> iargs({1});   //Axis = 1 |  | ||||||
|     std::vector<bool> bargs; |  | ||||||
|     inputs.emplace_back(const_cast<NDArray*>(x)); |  | ||||||
|     inputs.emplace_back(const_cast<NDArray*>(hLast)); |  | ||||||
| 
 | 
 | ||||||
|     auto result = concatOp.execute(inputs, targs, iargs, bargs); |     NDArray Wrx = (*W)({0,iS,     0,nU});       // [iS, nU] | ||||||
|     auto concatOut = result->at(0); |     NDArray Wux = (*W)({0,iS,     nU,2*nU});    // [iS, nU] | ||||||
|  |     NDArray Wrh = (*W)({iS,iS+nU, 0,nU});       // [nU, nU] | ||||||
|  |     NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU});    // [nU, nU] | ||||||
| 
 | 
 | ||||||
|     //mmul/z for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u) |     NDArray Wcx = (*Wc)({0,iS,     0,0});       // input-to-cell weights   [iS, nU] | ||||||
|     auto m = mmul(*concatOut, *Wru);    //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 2*numUnits] = [bs, 4*numUnits] |     NDArray Wch = (*Wc)({iS,iS+nU, 0,0});       // hidden-to-cell weights  [nU, nU] | ||||||
|     m += (*bru); |  | ||||||
| 
 | 
 | ||||||
|     sigmoidInplace(m);  //sigmoid(rz) and sigmoid(uz) |     NDArray br = (*b)({0,  nU});                // [nU] | ||||||
|     auto mr = m({0,0, 0, nU}); |     NDArray bu = (*b)({nU, 2*nU});              // [nU] | ||||||
|     auto mu = m({0,0, nU, 2*nU}); |  | ||||||
| 
 | 
 | ||||||
|     r->assign(&mr); |     // × means matrix multiplication | ||||||
|     u->assign(&mu); |     // * means element-wise product or so called Hadamard product | ||||||
| 
 | 
 | ||||||
|     //Concatenated inputs: [x, yt-1 .* r] |     // reset gate | ||||||
|     auto yr = (*concatOut)({0,0, nIn, nIn+nU}); |     r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br);         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] | ||||||
|     yr *= (*r); |     r->applyTransform(transform::Sigmoid); | ||||||
| 
 | 
 | ||||||
|     //c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c) |     // update gate | ||||||
|     MmulHelper::mmul(concatOut, const_cast<NDArray*>(Wc), c, 1.0, 0.0);       //c = 1.0 * concatOut * Wc + 0.0 * c |     u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu);         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] | ||||||
|  |     u->applyTransform(transform::Sigmoid); | ||||||
|  | 
 | ||||||
|  |     // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) | ||||||
|  |     c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc);    // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] | ||||||
|  |     c->applyTransform(transform::Tanh); | ||||||
|  | 
 | ||||||
|  |     NDArray temp = 1.f - *c * *c; | ||||||
|  | 
 | ||||||
|  |     // cell output | ||||||
|  |     h->assign(*u * *hLast + (1.f - *u) * *c); | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     /***************************************************************************************/ | ||||||
|  |     /*************** THIS IS MORE OPTIMIZED CODE (should think about concat) ***************/ | ||||||
|  |     /***************************************************************************************/ | ||||||
|  | /* | ||||||
|  |     //Concat inputs: x + hLast : [bs, iS + nU] | ||||||
|  |     NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context);  // concat([bs, iS], [bs, nU]) -> [bs, iS + nU] | ||||||
|  |     helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)},  xhConcat, {1}); | ||||||
|  | 
 | ||||||
|  |     //mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u) | ||||||
|  |     auto m = mmul(xhConcat, *W) + *b ;    // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU] | ||||||
|  |     // m += *bru; | ||||||
|  | 
 | ||||||
|  |     m.applyTransform(transform::Sigmoid);  //sigmoid(rz) and sigmoid(uz) | ||||||
|  | 
 | ||||||
|  |     r->assign(m({0,0, 0, nU})); | ||||||
|  |     u->assign(m({0,0, nU, 2*nU})); | ||||||
|  | 
 | ||||||
|  |     // hLast = hLast * r | ||||||
|  |     xhConcat({0,0, iS, iS+nU}) *= *r; | ||||||
|  | 
 | ||||||
|  |     //c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c) | ||||||
|  |     MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0);       //c = 1.0 * xhConcat * Wc + 0.0 * c | ||||||
|     *c += *bc; |     *c += *bc; | ||||||
|     tanhInplace(*c); |     c->applyTransform(transform::Tanh); | ||||||
| 
 | 
 | ||||||
|     //Output: h = (1-u).*c + u .* hPrev |     //Output: h = (1-u).*c + u .* hPrev | ||||||
|     //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult); |     //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult); | ||||||
| @ -94,115 +124,238 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa | |||||||
|     auto temp = (1.0f - *u); |     auto temp = (1.0f - *u); | ||||||
|     temp *= (*c); |     temp *= (*c); | ||||||
|     (*h) += temp; |     (*h) += temp; | ||||||
| 
 | */ | ||||||
|     delete result; |  | ||||||
| } | } | ||||||
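Collected in one place, the forward pass implemented by gruCell above, as a LaTeX restatement of the comments (× is matrix multiplication, ∗ the element-wise/Hadamard product; shapes as documented in the code):

    \begin{aligned}
    r &= \sigma\big(x\,W_{rx} + h_{last}\,W_{rh} + b_r\big) \\
    u &= \sigma\big(x\,W_{ux} + h_{last}\,W_{uh} + b_u\big) \\
    c &= \tanh\big(x\,W_{cx} + (r \ast h_{last})\,W_{ch} + b_c\big) \\
    h &= u \ast h_{last} + (1 - u) \ast c
    \end{aligned}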
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { | void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { | ||||||
| 
 | 
 | ||||||
| } |     // x   input [time, bS, iS] | ||||||
| 
 |     // hLast  initial cell output (at time step = 0) [bS, nU] | ||||||
| ////////////////////////////////////////////////////////////////////////// |  | ||||||
| void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0, |  | ||||||
|                const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { |  | ||||||
| 
 |  | ||||||
|     // x                        input [bS, iS] |  | ||||||
|     // h0                       previous cell output [bS, nU],  that is at previous time step t-1 |  | ||||||
|     // Wx  input-to-hidden  weights, [iS, 3*nU] |     // Wx  input-to-hidden  weights, [iS, 3*nU] | ||||||
|     // Wh  hidden-to-hidden weights, [nU, 3*nU] |     // Wh  hidden-to-hidden weights, [nU, 3*nU] | ||||||
|     // b   biases, [3*nU] |     // b   biases, [3*nU] | ||||||
|     // dLdh                     gradient wrt output, [bS,nU], that is epsilon_next |  | ||||||
|     // dLdWx0                   gradient wrt Wx at previous time step, [iS, 3*nU] |  | ||||||
|     // dLdWh0                   gradient wrt Wh at previous time step, [nU, 3*nU] |  | ||||||
|     // dLdb0                    gradient wrt b at previous time step,  [3*nU] |  | ||||||
| 
 | 
 | ||||||
|     // dLdx                   gradient wrt x,  [bS, iS], that is epsilon |     // h is cell outputs at each time step [time, bS, nU] | ||||||
|     // dLdh0                  gradient wrt h0, [bS, nU] |  | ||||||
|     // dLdWx                  gradient wrt Wx, [iS, 3*nU] |  | ||||||
|     // dLdWh                  gradient wrt Wh, [nU, 3*nU] |  | ||||||
|     // dLdb                   gradient wrt b at previous time step,  [3*nU] |  | ||||||
| 
 | 
 | ||||||
|     // h is current cell output [bS, nU], that is at current time step t |     const int time = x->sizeAt(0); | ||||||
|  | 
 | ||||||
|  |     NDArray ht_1(*hLast); | ||||||
|  | 
 | ||||||
|  |     // loop through time steps | ||||||
|  |     for (int t = 0; t < time; ++t) { | ||||||
|  | 
 | ||||||
|  |         auto xt = (*x)({t,t+1, 0,0, 0,0}); | ||||||
|  |         auto ht = (*h)({t,t+1, 0,0, 0,0}); | ||||||
|  | 
 | ||||||
|  |         // helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht); | ||||||
|  |         // ht_1.assign(ht); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////////// | ||||||
|  | void gruCellBP(nd4j::LaunchContext* context, | ||||||
|  |               const NDArray* x,    const NDArray* hLast, | ||||||
|  |               const NDArray* W,    const NDArray* Wc,        const NDArray* b,    const NDArray* bc, | ||||||
|  |               const NDArray* dLdr, const NDArray* dLdu,      const NDArray* dLdc, const NDArray* dLdh, | ||||||
|  |                     NDArray* dLdx,       NDArray* dLdhLast, | ||||||
|  |                     NDArray* dLdW,       NDArray* dLdWc, | ||||||
|  |                     NDArray* dLdb,       NDArray* dLdbc) { | ||||||
|  | 
 | ||||||
|  |     //Inputs: | ||||||
|  |     // x              input [bS, iS] | ||||||
|  |     // hLast          previous cell output [bS, nU],  that is at previous time step t-1 | ||||||
|  |     // W              weights - [iS+nU, 2*nU] - reset and update gates | ||||||
|  |     // Wc             C weights - [iS+nU, nU] - cell gate | ||||||
|  |     // b              r and u biases, [2*nU] - reset and update gates | ||||||
|  |     // bc             c biases, [nU] - cell gate | ||||||
|  |     // dLdr           gradient wrt reset gate, [bS, nU] | ||||||
|  |     // dLdu           gradient wrt update gate, [bS, nU] | ||||||
|  |     // dLdc           gradient wrt cell state, [bS, nU] | ||||||
|  |     // dLdh           gradient wrt current cell output, [bS, nU] | ||||||
|  | 
 | ||||||
|  |     //Outputs: | ||||||
|  |     // dLdx           gradient wrt x,  [bS, iS], | ||||||
|  |     // dLdhLast       gradient wrt hLast, [bS, nU] | ||||||
|  |     // dLdW           gradient wrt W,  [iS+nU, 2*nU] | ||||||
|  |     // dLdWc          gradient wrt Wc, [iS+nU, nU] | ||||||
|  |     // dLdb           gradient wrt bru [2*nU] | ||||||
|  |     // dLdbc          gradient wrt bc  [nU] | ||||||
|  | 
 | ||||||
|  |     // * means element-wise product or so called Hadamard product | ||||||
|  |     // × means matrix multiplication | ||||||
|  | 
 | ||||||
|  |     /************************************************************************************************/ | ||||||
|  |     /******************************* THIS IS NOT OPTIMIZED CODE *************************************/ | ||||||
|  |     /*** the aim is to have math-readable code in order to keep track of backprop formula derivation ***/ | ||||||
|  | 
 | ||||||
|  |     const int bS  = x->sizeAt(0); | ||||||
|  |     const int iS = x->sizeAt(1); | ||||||
|  |     const int nU = hLast->sizeAt(1); | ||||||
|  | 
 | ||||||
|  |     NDArray xT     = x->transpose();            // [iS, bS] | ||||||
|  |     NDArray hLastT = hLast->transpose();        // [nU, bS] | ||||||
|  | 
 | ||||||
|  |     NDArray Wrx = (*W)({0,iS,     0,nU});       // [iS, nU] | ||||||
|  |     NDArray Wux = (*W)({0,iS,     nU,2*nU});    // [iS, nU] | ||||||
|  |     NDArray Wrh = (*W)({iS,iS+nU, 0,nU});       // [nU, nU] | ||||||
|  |     NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU});    // [nU, nU] | ||||||
|  | 
 | ||||||
|  |     NDArray Wcx = (*Wc)({0,iS,     0,0});       // input-to-cell weights   [iS, nU] | ||||||
|  |     NDArray Wch = (*Wc)({iS,iS+nU, 0,0});       // hidden-to-cell weights  [nU, nU] | ||||||
|  | 
 | ||||||
|  |     NDArray br = (*b)({0,  nU});                // [nU] | ||||||
|  |     NDArray bu = (*b)({nU, 2*nU});              // [nU] | ||||||
|  | 
 | ||||||
|  |     NDArray WrxT = Wrx.transpose();             // [nU, iS] | ||||||
|  |     NDArray WuxT = Wux.transpose();             // [nU, iS] | ||||||
|  |     NDArray WrhT = Wrh.transpose();             // [nU, nU] | ||||||
|  |     NDArray WuhT = Wuh.transpose();             // [nU, nU] | ||||||
|  | 
 | ||||||
|  |     NDArray WcxT = Wcx.transpose();             // [nU, iS] | ||||||
|  |     NDArray WchT = Wch.transpose();             // [nU, nU] | ||||||
|  | 
 | ||||||
|  |     NDArray dLdWrx = (*dLdW)({0,iS,     0,nU});     // [iS, nU] | ||||||
|  |     NDArray dLdWux = (*dLdW)({0,iS,     nU,2*nU});  // [iS, nU] | ||||||
|  |     NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU});     // [nU, nU] | ||||||
|  |     NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU});  // [nU, nU] | ||||||
|  | 
 | ||||||
|  |     NDArray dLdWcx = (*dLdWc)({0,iS,     0,0});     // [iS, nU] | ||||||
|  |     NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0});     // [nU, nU] | ||||||
|  | 
 | ||||||
|  |     NDArray dLdbr = (*dLdb)({0,  nU});              // [nU] | ||||||
|  |     NDArray dLdbu = (*dLdb)({nU, 2*nU});            // [nU] | ||||||
| 
 | 
 | ||||||
|     const int nU = h0->sizeAt(1); |  | ||||||
| 
 | 
 | ||||||
|     // ***** feed forward step ***** // |     // ***** feed forward step ***** // | ||||||
|     // gates = sigmoid(x*Wx + h0*Wh + b) | 
 | ||||||
|     auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nU})) + mmul(*h0, (*Wh)({0,0, 0,2*nU})) + (*b)({0,2*nU}));       // [bS, 2*nU] + [bS, 2*nU] + [1, 2*nU] = [bS, 2*nU] |  | ||||||
|     // reset gate |     // reset gate | ||||||
|     auto r = gates({0,0, 0, nU});               // [bS, nU] |     NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] | ||||||
|  |     r.applyTransform(transform::Sigmoid); | ||||||
|  | 
 | ||||||
|     // update gate |     // update gate | ||||||
|     auto u = gates({0,0, nU, 2*nU});            // [bS, nU] |     NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] | ||||||
|     // ◦ means element-wise product or so called Hadamard product |     u.applyTransform(transform::Sigmoid); | ||||||
|     // n = tanh(x*Wx + (r◦h0)*Wh + b) | 
 | ||||||
|     auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nU,3*nU})) + mmul((*h0)*r, (*Wh)({0,0, 2*nU,3*nU})) + (*b)({2*nU,3*nU}));     // [bS, nU] |     // cell gate c = activation(x×Wcx + (r*hLast)×Wch + bc) | ||||||
|  |     NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc;    // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] | ||||||
|  |     c.applyTransform(transform::Tanh); | ||||||
|  | 
 | ||||||
|  |     // h = (1 - u) * c + u * hPrev | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
|     // ***** back prop step ***** // |     // ***** back prop step ***** // | ||||||
|     auto Wxr  = (*Wx)({0,0, 0,   nU}); |  | ||||||
|     auto Wxu  = (*Wx)({0,0, nU,  2*nU}); |  | ||||||
|     auto Wxn  = (*Wx)({0,0, 2*nU,3*nU}); |  | ||||||
|     auto Whr  = (*Wh)({0,0, 0,   nU}); |  | ||||||
|     auto Whu  = (*Wh)({0,0, nU,  2*nU}); |  | ||||||
|     auto Whn  = (*Wh)({0,0, 2*nU,3*nU}); |  | ||||||
|     auto WxrT = Wxr.transpose(); |  | ||||||
|     auto WxuT = Wxu.transpose(); |  | ||||||
|     auto WxnT = Wxn.transpose(); |  | ||||||
|     auto WhrT = Whr.transpose(); |  | ||||||
|     auto WhuT = Whu.transpose(); |  | ||||||
|     auto WhnT = Whn.transpose(); |  | ||||||
|     auto xT   = x->transpose(); |  | ||||||
|     auto h0T  = h0->transpose(); |  | ||||||
| 
 | 
 | ||||||
|     auto dLdWxr = (*dLdWx)({0,0, 0,     nU}); |     // notations: | ||||||
|     auto dLdWxu = (*dLdWx)({0,0, nU,  2*nU}); |     // Zr = x × Wrx + hLast × Wrh + br | ||||||
|     auto dLdWxn = (*dLdWx)({0,0, 2*nU,3*nU}); |     // Zu = x × Wux + hLast × Wuh + bu | ||||||
|  |     // Sr = sigmoid(Zr) | ||||||
|  |     // Su = sigmoid(Zu) | ||||||
|  |     // Zc = x × Wcx + (r * hlast) × Wch + bc | ||||||
| 
 | 
 | ||||||
|     auto dLdWhr = (*dLdWh)({0,0, 0,     nU}); |  | ||||||
|     auto dLdWhu = (*dLdWh)({0,0, nU,  2*nU}); |  | ||||||
|     auto dLdWhn = (*dLdWh)({0,0, 2*nU,3*nU}); |  | ||||||
| 
 | 
 | ||||||
|     auto dLdbr = (*dLdb)({0,     nU}); |     // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx | ||||||
|     auto dLdbu = (*dLdb)({nU,  2*nU}); |     //      = dLdx_u + dLdx_c | ||||||
|     auto dLdbn = (*dLdb)({2*nU,3*nU}); |     // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT | ||||||
|  |     // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 | ||||||
|  |     // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT | ||||||
|  |     // dZcdr = (... * hLast) × WchT | ||||||
|  |     // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT | ||||||
|  |     // drdx = drdZr * dZrdx | ||||||
|  |     // dZrdx = ... × WrxT | ||||||
|  |     // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT | ||||||
|  |     // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT | ||||||
| 
 | 
 | ||||||
|     auto dhdu   = *h0  - n;              // [bS, nU] |  | ||||||
|     auto dhdn   = 1.f - u;               // [bS, nU] |  | ||||||
|     auto dSigdu = u * (1.f - u);         // [bS, nU] |  | ||||||
|     auto dSigdr = r * (1.f - r);         // [bS, nU] |  | ||||||
|     auto dActdn = 1.f - n * n;           // [bS, nU] |  | ||||||
|     auto dndr   = mmul(dActdn * (*h0), WhnT); |  | ||||||
|     auto drdh0  = mmul(dSigdr, WhrT); |  | ||||||
| 
 | 
 | ||||||
|     auto dLdn = (*dLdh) * dhdn; |     // dLdhLast    = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast | ||||||
|     auto dLdu = (*dLdh) * dhdu; |     //             = dLdhLast_h + dLdhLast_u + dLdhLast_c | ||||||
|     auto dLdr = dLdn * dndr; |     // dLdhLast_h  = dLdh * dhdhLast = dLdh * u | ||||||
|  |     // dLdhLast_u  = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT | ||||||
|  |     // dLdhLast_c  = dLdc * dcdhLast  = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = | ||||||
|  |     //             = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = | ||||||
|  |     //             = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 | ||||||
|  |     // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT | ||||||
|  |     // dLdhLast_c1 = dLdr * drdhLast = |drdhLast  = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT | ||||||
|  |     // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = | ||||||
|  |     //                  = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT | ||||||
| 
 | 
 | ||||||
|     dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) );      // [bS,iS] |  | ||||||
|     dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u );       // [bS,nU] |  | ||||||
| 
 | 
 | ||||||
|     dLdWxr.assign( mmul(xT, dSigdr * dLdr) );                                                               //  [iS,nU] |     // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = | ||||||
|     dLdWhr.assign( mmul(h0T, dSigdr * dLdr) );                                                              //  [nU,nU] |     //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx | ||||||
|  |     // dZrdWrx = xT × ... | ||||||
|  |     // finally dLdWrx = xT × (dLdr * drdZr) | ||||||
| 
 | 
 | ||||||
|     dLdWxu.assign( mmul(xT, dSigdu * dLdu) );                                                               //  [iS,nU] |  | ||||||
|     dLdWhu.assign( mmul(h0T, dSigdu * dLdu) );                                                              //  [nU,nU] |  | ||||||
| 
 | 
 | ||||||
|     dLdWxn.assign( mmul(xT, dActdn * dLdn) );                                                               //  [iS,nU] |     // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = | ||||||
|     dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) );                                               //  [nU,nU] |     //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh | ||||||
|  |     // dZrdWrh = hLastT × ... | ||||||
|  |     // finally dLdWrh = hLastT × (dLdr * drdZr) | ||||||
| 
 | 
 | ||||||
|     dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0}));                          // [nU] |  | ||||||
|     dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0}));                          // [nU] |  | ||||||
|     dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0}));                          // [nU] |  | ||||||
| 
 | 
 | ||||||
|     if(dLdWx0 != nullptr) |     // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux | ||||||
|         *dLdWx += *dLdWx0; |     // dZudWux = xT × ... | ||||||
|  |     // finally dLdWux = xT × (dLdu * dudZu) | ||||||
| 
 | 
 | ||||||
|     if(dLdWh0 != nullptr) |  | ||||||
|         *dLdWh += *dLdWh0; |  | ||||||
| 
 | 
 | ||||||
|     if(dLdb0 != nullptr) |     // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh | ||||||
|         *dLdb += *dLdb0; |     // dZudWuh = hLastT × ... | ||||||
|  |     // finally dLdWuh = hLastT × (dLdu * dudZu) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx | ||||||
|  |     // dZcdWcx = xT × ... | ||||||
|  |     // finally dLdWcx = xT × (dLdc * dcdZc) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch | ||||||
|  |     // dZcdWch = (r*hLast)^T × ... | ||||||
|  |     // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = | ||||||
|  |     //       = dLdr * drdZr * dZrdbr | ||||||
|  |     // dZrdbr = 1 | ||||||
|  |     // finally dLdbr = dLdr * drdZr | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu | ||||||
|  |     // dZudbu = 1 | ||||||
|  |     // finally dLdbu = dLdu * dudZu | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc | ||||||
|  |     // dZcdbc = 1 | ||||||
|  |     // finally dLdbc = dLdc * dcdZc | ||||||
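Collecting the results derived in the comments above in one place (same notation; ∘ is the elementwise product, and dLdc = dLdh ∘ (1 − u), dLdu = dLdh ∘ (h_last − c), dLdr = (dLdZ_c ∘ h_last) W_{ch}^T as derived earlier), the quantities assigned by the code below are:

\begin{aligned}
dLdZ_r &= dLdr \circ r(1-r), \quad dLdZ_u = dLdu \circ u(1-u), \quad dLdZ_c = dLdc \circ (1-c^2) \\
dLdx &= dLdZ_u\,W_{ux}^T + dLdZ_c\,W_{cx}^T + dLdZ_r\,W_{rx}^T \\
dLdh_{last} &= dLdh \circ u + dLdZ_u\,W_{uh}^T + (dLdZ_c \circ r)\,W_{ch}^T + dLdZ_r\,W_{rh}^T \\
dLdW_{rx} &= x^T\,dLdZ_r, \quad dLdW_{rh} = h_{last}^T\,dLdZ_r \\
dLdW_{ux} &= x^T\,dLdZ_u, \quad dLdW_{uh} = h_{last}^T\,dLdZ_u \\
dLdW_{cx} &= x^T\,dLdZ_c, \quad dLdW_{ch} = (r \circ h_{last})^T\,dLdZ_c \\
dLdb_r &= \textstyle\sum_{batch} dLdZ_r, \quad dLdb_u = \textstyle\sum_{batch} dLdZ_u, \quad dLdb_c = \textstyle\sum_{batch} dLdZ_c
\end{aligned}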
|  | 
 | ||||||
|  |     NDArray dhdc  = 1.f - u;           // [bS, nU] | ||||||
|  |     NDArray dhdu  = *hLast - c;        // [bS, nU] | ||||||
|  |     NDArray dudZu = u * dhdc;          // [bS, nU] | ||||||
|  |     NDArray drdZr = r * (1.f - r);     // [bS, nU] | ||||||
|  |     NDArray dcdZc = 1.f - c * c;       // [bS, nU] | ||||||
|  |     NDArray dLdZc = *dLdc * dcdZc;     // [bS, nU] | ||||||
|  |     NDArray dLdZu = *dLdu * dudZu;     // [bS, nU] | ||||||
|  |     NDArray dLdZr = *dLdr * drdZr;     // [bS, nU] | ||||||
|  | 
 | ||||||
|  |     // NDArray dLdc  = *dLdh * dhdc;                       // [bS, nU] | ||||||
|  |     // NDArray dLdu  = *dLdh * dhdu;                       // [bS, nU] | ||||||
|  |     // NDArray dLdr  = mmul(dLdc * dcdZc * *hLast, WchT);  // [bS, nU] | ||||||
|  | 
 | ||||||
|  |     dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT));                        // [bS, iS] | ||||||
|  | 
 | ||||||
|  |     dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT));    // [bS, nU] | ||||||
|  | 
 | ||||||
|  |     dLdWrx.assign(mmul(xT,     dLdZr));     // [iS, bS] × [bS, nU] = [iS, nU] | ||||||
|  |     dLdWrh.assign(mmul(hLastT, dLdZr));     // [nU, bS] × [bS, nU] = [nU, nU] | ||||||
|  |     dLdWux.assign(mmul(xT,     dLdZu));     // [iS, bS] × [bS, nU] = [iS, nU] | ||||||
|  |     dLdWuh.assign(mmul(hLastT, dLdZu));     // [nU, bS] × [bS, nU] = [nU, nU] | ||||||
|  | 
 | ||||||
|  |     dLdWcx.assign(mmul(xT, dLdZc));                          // [iS, bS] × [bS, nU] = [iS, nU] | ||||||
|  |     dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc));    // [nU, bS] × [bS, nU] = [nU, nU] | ||||||
|  | 
 | ||||||
|  |     dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0}));  // [nU] | ||||||
|  |     dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0}));  // [nU] | ||||||
|  | 
 | ||||||
|  |     dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU] | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -20,12 +20,111 @@ | |||||||
| 
 | 
 | ||||||
| #include <ops/declarable/helpers/hashcode.h> | #include <ops/declarable/helpers/hashcode.h> | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| namespace nd4j { | namespace nd4j { | ||||||
|     namespace ops { |     namespace ops { | ||||||
|         namespace helpers { |         namespace helpers { | ||||||
|  |             template <typename T> | ||||||
|  |             static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong length) { | ||||||
|  | 
 | ||||||
|  |                 for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) { | ||||||
|  |                     auto blockBuffer = buffer + b * blockSize; | ||||||
|  | 
 | ||||||
|  |                     Nd4jLong r = 1; | ||||||
|  |                     for (int e = threadIdx.x; e < blockSize && e + (b * blockSize) < length; e += blockDim.x) { | ||||||
|  |                         auto v = longBytes<T>(blockBuffer[e]); | ||||||
|  |                         r = 31 * r + v; | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                     tempBuffer[b] = r; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             template <typename T> | ||||||
|  |             static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong lastLength) { | ||||||
|  | 
 | ||||||
|  |                 for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) { | ||||||
|  |                     auto blockBuffer = tempBuffer + b * blockSize; | ||||||
|  | 
 | ||||||
|  |                     Nd4jLong r = 1; | ||||||
|  |                     for (int e = threadIdx.x; e < blockSize && e + (b * blockSize) < lastLength; e += blockDim.x) { | ||||||
|  |                         auto v = longBytes<T>(blockBuffer[e]); | ||||||
|  |                         r = 31 * r + v; | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                     tempResult[b] = r; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |             static __global__ void lastStep(Nd4jLong* resultBuf, Nd4jLong* tempBufferA, Nd4jLong* tempResult, Nd4jLong length, Nd4jLong blockSize) { | ||||||
|  |                 if (threadIdx.x == 0) { | ||||||
|  | 
 | ||||||
|  |                     if (length <= blockSize) | ||||||
|  |                         *resultBuf = *tempBufferA; | ||||||
|  |                     else | ||||||
|  |                         *resultBuf = *tempResult; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             template <typename T> | ||||||
|  |             void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) { | ||||||
|  |                 auto blockSize = 32; | ||||||
|  |                 auto stream = context->getCudaStream(); | ||||||
|  |                 array.syncToDevice(); | ||||||
|  | 
 | ||||||
|  |                 NDArray::prepareSpecialUse({&result}, {&array}); | ||||||
|  |                 auto length = array.lengthOf(); | ||||||
|  |                 int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1); | ||||||
|  |                 auto tempA = NDArrayFactory::create<Nd4jLong>('c', {numBlocks}, context); | ||||||
|  |                 auto tempB = NDArrayFactory::create<Nd4jLong>('c', { numBlocks / blockSize + 1}, context); | ||||||
|  | 
 | ||||||
|  |                 auto buffer = reinterpret_cast<T*>(array.specialBuffer()); //bufferAsT<T>(); | ||||||
|  |                 auto tempBufferA = reinterpret_cast<Nd4jLong*>(tempA.specialBuffer()); //bufferAsT<Nd4jLong>(); | ||||||
|  |                 auto tempBufferB = reinterpret_cast<Nd4jLong*>(tempB.specialBuffer()); //bufferAsT<Nd4jLong>(); | ||||||
|  | 
 | ||||||
|  |                 // the default buffer is the first one, because for small arrays (< blockSize) it also holds the final result | ||||||
|  |                 auto tempBuffer = tempBufferA; | ||||||
|  |                 auto tempResult = tempBufferB; | ||||||
|  | 
 | ||||||
|  |                 // we divide the array into chunks of blockSize (32) elements and store one intermediate hash per chunk | ||||||
|  |                 splitBufferToChuncks<T><<<numBlocks, length, 1024, *stream>>>(buffer, tempBuffer, numBlocks, blockSize, length); | ||||||
|  | 
 | ||||||
|  |                 // we then repeatedly hash the intermediate buffer (swapping the two temp buffers) until only one chunk is left | ||||||
|  |                 int iterationCount = 0; | ||||||
|  |                 while (numBlocks > 1) { | ||||||
|  |                     int lastLength = numBlocks; | ||||||
|  |                     numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1); | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |                     internalHash<Nd4jLong><<<numBlocks, lastLength, 1024, *stream>>>(tempBuffer, tempResult, numBlocks, blockSize, lastLength); | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |                     iterationCount++; | ||||||
|  |                     // swapping buffers | ||||||
|  |                     if (iterationCount % 2 == 0) { | ||||||
|  |                         tempBuffer = tempBufferA; | ||||||
|  |                         tempResult = tempBufferB; | ||||||
|  |                     } else { | ||||||
|  |                         tempBuffer = tempBufferB; | ||||||
|  |                         tempResult = tempBufferA; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 //lastStep<Nd4jLong><<<1,1,128, *stream>>>(result.specialBuffer(), tempBufferA, tempResult, length, blockSize); | ||||||
|  |                 tempA.syncToHost(); | ||||||
|  |                 tempB.syncToHost(); | ||||||
|  |                 result.assign((length <= blockSize ? tempA.e(0) : tempB.e(0))); | ||||||
|  | 
 | ||||||
|  |                 NDArray::registerSpecialUse({&result}, {&array}); | ||||||
|  |             } | ||||||
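To make the reduction scheme of splitBufferToChuncks and internalHash easier to follow, here is a minimal single-threaded CPU sketch of the same idea (31·r + v hashing over fixed-size chunks, repeated until one value is left). It is an illustrative analogue only, not part of the library, and it skips the longBytes element conversion:

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// hash one chunk exactly the way the kernels do: r = 31 * r + v
static int64_t hashChunk(const int64_t* data, int64_t len) {
    int64_t r = 1;
    for (int64_t e = 0; e < len; e++)
        r = 31 * r + data[e];
    return r;
}

// collapse the buffer by blockSize repeatedly until a single hash remains,
// mirroring the first splitBufferToChuncks pass and the internalHash loop
static int64_t hashCodeReference(std::vector<int64_t> values, int64_t blockSize = 32) {
    if (values.empty())
        return 1;
    do {
        std::vector<int64_t> next((values.size() + blockSize - 1) / blockSize);
        for (size_t b = 0; b < next.size(); b++) {
            size_t start = b * blockSize;
            size_t len = std::min<size_t>(blockSize, values.size() - start);
            next[b] = hashChunk(values.data() + start, static_cast<int64_t>(len));
        }
        values = std::move(next);
    } while (values.size() > 1);
    return values[0];
}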
|  | 
 | ||||||
|             void hashCode(LaunchContext *context, NDArray &array, NDArray &result) { |             void hashCode(LaunchContext *context, NDArray &array, NDArray &result) { | ||||||
| 
 |                 BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES); | ||||||
|             } |             } | ||||||
|  | 
 | ||||||
|  |             BUILD_SINGLE_TEMPLATE(template void hashCode_, (LaunchContext* context, NDArray& array, NDArray& result), LIBND4J_TYPES); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -20,6 +20,8 @@ | |||||||
| 
 | 
 | ||||||
| #include <ops/declarable/helpers/image_suppression.h> | #include <ops/declarable/helpers/image_suppression.h> | ||||||
| #include <NDArrayFactory.h> | #include <NDArrayFactory.h> | ||||||
|  | #include <NativeOps.h> | ||||||
|  | #include <cuda_exception.h> | ||||||
| 
 | 
 | ||||||
| namespace nd4j { | namespace nd4j { | ||||||
| namespace ops { | namespace ops { | ||||||
| @ -35,15 +37,16 @@ namespace helpers { | |||||||
|         Nd4jLong next1[] = {nextIndex, 1}; |         Nd4jLong next1[] = {nextIndex, 1}; | ||||||
|         Nd4jLong next2[] = {nextIndex, 2}; |         Nd4jLong next2[] = {nextIndex, 2}; | ||||||
|         Nd4jLong next3[] = {nextIndex, 3}; |         Nd4jLong next3[] = {nextIndex, 3}; | ||||||
| 
 |         Nd4jLong* shapeOf = shape::shapeOf(boxesShape); | ||||||
|         T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous2, 2)]); |         Nd4jLong* strideOf = shape::stride(boxesShape); | ||||||
|         T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous3, 2)]); |         T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, previous0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous2, 2)]); | ||||||
|         T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous2, 2)]); |         T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, previous1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous3, 2)]); | ||||||
|         T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous3, 2)]); |         T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, previous0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous2, 2)]); | ||||||
|         T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next2, 2)]); |         T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, previous1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous3, 2)]); | ||||||
|         T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next3, 2)]); |         T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, next0, 2)],     boxes[shape::getOffset(0, shapeOf, strideOf, next2, 2)]); | ||||||
|         T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next2, 2)]); |         T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, next1, 2)],     boxes[shape::getOffset(0, shapeOf, strideOf, next3, 2)]); | ||||||
|         T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next3, 2)]); |         T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, next0, 2)],     boxes[shape::getOffset(0, shapeOf, strideOf, next2, 2)]); | ||||||
|  |         T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, next1, 2)],     boxes[shape::getOffset(0, shapeOf, strideOf, next3, 2)]); | ||||||
| 
 | 
 | ||||||
|         T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); |         T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); | ||||||
|         T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); |         T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); | ||||||
| @ -62,149 +65,101 @@ namespace helpers { | |||||||
|     }; |     }; | ||||||
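The lines above compute the extents and areas of the two boxes; the remainder of needToSuppressWithThreshold (outside this hunk) turns that into an overlap test against the threshold. As a reading aid, a plain C++ sketch of such a test follows; the [y1, x1, y2, x2] layout and the intersection-over-union criterion are assumptions for illustration, not a verbatim copy of the library routine:

#include <algorithm>

struct Box { double y1, x1, y2, x2; };   // hypothetical layout matching the index pairs above

static bool suppresses(const Box& prev, const Box& next, double threshold) {
    double minYPrev = std::min(prev.y1, prev.y2), maxYPrev = std::max(prev.y1, prev.y2);
    double minXPrev = std::min(prev.x1, prev.x2), maxXPrev = std::max(prev.x1, prev.x2);
    double minYNext = std::min(next.y1, next.y2), maxYNext = std::max(next.y1, next.y2);
    double minXNext = std::min(next.x1, next.x2), maxXNext = std::max(next.x1, next.x2);

    double areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
    double areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
    if (areaPrev <= 0.0 || areaNext <= 0.0)
        return false;

    // intersection rectangle of the two boxes
    double iy = std::max(0.0, std::min(maxYPrev, maxYNext) - std::max(minYPrev, minYNext));
    double ix = std::max(0.0, std::min(maxXPrev, maxXNext) - std::max(minXPrev, minXNext));
    double intersection = iy * ix;

    // suppress when intersection-over-union exceeds the threshold
    return intersection / (areaPrev + areaNext - intersection) > threshold;
}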
| 
 | 
 | ||||||
|     template <typename T, typename I> |     template <typename T, typename I> | ||||||
|     static __global__ void nonMaxSuppressionKernel(T* boxes, Nd4jLong* boxesShape, I* indices, int* selectedIndices, Nd4jLong numBoxes, I* output, Nd4jLong* outputShape, T threshold) { |     static __global__ void shouldSelectKernel(T* boxesBuf, Nd4jLong* boxesShape, I* indexBuf, I* selectedIndicesData, double threshold, int numSelected, int i, bool* shouldSelect) { | ||||||
|         __shared__ Nd4jLong outputLen; |         auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
| 
 |         auto step = gridDim.x * blockDim.x; | ||||||
|  |         __shared__ bool shouldSelectShared; | ||||||
|         if (threadIdx.x == 0) { |         if (threadIdx.x == 0) { | ||||||
|             outputLen = shape::length(outputShape); |             shouldSelectShared = shouldSelect[0]; | ||||||
|         } |         } | ||||||
|         __syncthreads(); |         __syncthreads(); | ||||||
|  |         for (int j = numSelected - 1 - tid; j >= 0; j -= step) { | ||||||
|  |             if (shouldSelectShared) { | ||||||
|  |                 if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i], | ||||||
|  |                                                                   indexBuf[selectedIndicesData[j]], T(threshold))) | ||||||
|  |                     shouldSelectShared = false; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             *shouldSelect = shouldSelectShared; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     template <typename I> | ||||||
| 
 | 
 | ||||||
|         auto numSelected = blockIdx.x; |     static __global__ void copyIndices(void* indices,  void* indicesLong, Nd4jLong len) { | ||||||
|         auto start = blockIdx.x * blockDim.x + threadIdx.x; |         __shared__ I* indexBuf; | ||||||
|  |         __shared__ Nd4jLong* srcBuf; | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             indexBuf = reinterpret_cast<I*>(indices); | ||||||
|  |             srcBuf = reinterpret_cast<Nd4jLong*>(indicesLong); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  |         auto tid = threadIdx.x + blockIdx.x * blockDim.x; | ||||||
|         auto step = blockDim.x * gridDim.x; |         auto step = blockDim.x * gridDim.x; | ||||||
| //        for (int numSelected = blockIdx.x; numSelected < outputLen; numSelected += gridDim.x) { |  | ||||||
|         for (int i = start; i < numBoxes; i += step) { |  | ||||||
|                 bool shouldSelect = true; |  | ||||||
|                 for (int j = numSelected - 1; shouldSelect && j >= 0; --j) { |  | ||||||
|                     if (needToSuppressWithThreshold<T>(boxes, boxesShape, indices[i], indices[selectedIndices[j]], threshold)) { |  | ||||||
|                         shouldSelect = false; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
| 
 | 
 | ||||||
|                 if (shouldSelect) { |         for (auto i = tid; i < len; i += step) | ||||||
|                     auto zPos = shape::getIndexOffset(numSelected, outputShape, outputLen); |             indexBuf[i] = (I)srcBuf[i]; | ||||||
|                     output[zPos] = indices[i]; |  | ||||||
|                     selectedIndices[numSelected] = i; |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     template <typename T, typename I> |  | ||||||
|     static __global__ void sortIndices(I* indices, Nd4jLong* indexShape, T* scores, Nd4jLong* scoreShape) { |  | ||||||
|         __shared__ Nd4jLong len; |  | ||||||
| //        __shared__ Nd4jLong* sortedPart; |  | ||||||
| //        __shared__ Nd4jLong part; |  | ||||||
| //        __shared__ Nd4jLong partSize; |  | ||||||
| 
 |  | ||||||
|         if (threadIdx.x == 0) { |  | ||||||
| //            blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs;     // ceil |  | ||||||
| //            part = blockIdx.x / blocksPerArr; |  | ||||||
| 
 |  | ||||||
|             len = shape::length(indexShape); |  | ||||||
| //            __shared__ Nd4jLong* shmem = shared[]; |  | ||||||
| //            sortedPart = shmem; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         for (int m = 0; m < len; m++) { |  | ||||||
|             if (m % 2 == 0) { |  | ||||||
|                 for (int tid = threadIdx.x; tid < len; tid += blockDim.x) { |  | ||||||
|                     auto top = 2 * tid + 1; |  | ||||||
|                     if (top < len) { |  | ||||||
|                         auto t0 = shape::getIndexOffset(top - 1, indexShape, len); |  | ||||||
|                         auto t1 = shape::getIndexOffset(top, indexShape, len); |  | ||||||
|                         auto z0 = shape::getIndexOffset(top - 1, scoreShape, len); |  | ||||||
|                         auto z1 = shape::getIndexOffset(top, scoreShape, len); |  | ||||||
| 
 |  | ||||||
|                         if (scores[t0] < scores[t1]) { |  | ||||||
|                             // swap indices first |  | ||||||
|                             Nd4jLong di0 = indices[t0]; |  | ||||||
|                             indices[t0] = indices[t1]; |  | ||||||
|                             indices[t1] = di0; |  | ||||||
| 
 |  | ||||||
|                             //swap scores next |  | ||||||
| //                            T dz0 = scores[z0]; |  | ||||||
| //                            scores[z0] = scores[z1]; |  | ||||||
| //                            scores[z1] = dz0; |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 for (int tid = threadIdx.x; tid < len; tid += blockDim.x) { |  | ||||||
|                     auto top = 2 * tid + 2; |  | ||||||
|                     if (top < len) { |  | ||||||
|                         auto t0 = shape::getIndexOffset(top - 1, indexShape, len); |  | ||||||
|                         auto t1 = shape::getIndexOffset(top, indexShape, len); |  | ||||||
|                         auto z0 = shape::getIndexOffset(top - 1, scoreShape, len); |  | ||||||
|                         auto z1 = shape::getIndexOffset(top, scoreShape, len); |  | ||||||
| 
 |  | ||||||
|                         if (scores[t0] < scores[t1]) { |  | ||||||
|                             // swap indices first |  | ||||||
|                             Nd4jLong di0 = indices[t0]; |  | ||||||
|                             indices[t0] = indices[t1]; |  | ||||||
|                             indices[t1] = di0; |  | ||||||
| 
 |  | ||||||
|                             //swap scores next |  | ||||||
| //                            T dz0 = scores[z0]; |  | ||||||
| //                            scores[z0] = scores[z1]; |  | ||||||
| //                            scores[z1] = dz0; |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             __syncthreads(); |  | ||||||
|         } |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     template <typename T, typename I> |     template <typename T, typename I> | ||||||
|     static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { |     static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { | ||||||
|         auto stream = context->getCudaStream(); |         auto stream = context->getCudaStream(); | ||||||
|         NDArray::prepareSpecialUse({output}, {boxes, scales}); |         NDArray::prepareSpecialUse({output}, {boxes, scales}); | ||||||
|         NDArray* indices = NDArrayFactory::create_<I>('c', {scales->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext()); |         std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext()); | ||||||
|         indices->linspace(0); |         indices->linspace(0); | ||||||
|  |         indices->syncToDevice(); // linspace runs on the CPU only, so sync to device as well | ||||||
|  | 
 | ||||||
|         NDArray scores(*scales); |         NDArray scores(*scales); | ||||||
|         indices->syncToHost(); //linspace(0); |         NativeOps nativeOps; | ||||||
|         I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer()); | 
 | ||||||
|         T* scoreBuf = reinterpret_cast<T*>(scores.specialBuffer()); |         Nd4jPointer extras[2] = {nullptr, stream}; | ||||||
|         sortIndices<T, I><<<1, 32, 128, *stream>>>(indexBuf, indices->specialShapeInfo(), scoreBuf, scores.specialShapeInfo()); | 
 | ||||||
|  |         nativeOps.sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); | ||||||
|         // TO DO: sort indices using scales as value row |         // TO DO: sort indices using scales as value row | ||||||
|         //std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);}); |         //std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);}); | ||||||
|         indices->tickWriteDevice(); |         I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer()); | ||||||
|         indices->syncToHost(); |  | ||||||
|         indices->printIndexedBuffer("AFTERSORT OUTPUT"); |  | ||||||
|         NDArray selected = NDArrayFactory::create<int>({output->lengthOf()}); |  | ||||||
| 
 | 
 | ||||||
|         NDArray selectedIndices = NDArrayFactory::create<int>({output->lengthOf()}); |         NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()}); | ||||||
|         int numSelected = 0; |         int numSelected = 0; | ||||||
|         int numBoxes = boxes->sizeAt(0); |         int numBoxes = boxes->sizeAt(0); | ||||||
|         T* boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer()); |         T* boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer()); | ||||||
| //        Nd4jLong* indicesData = reinterpret_cast<Nd4jLong*>(indices->specialBuffer()); | 
 | ||||||
| //        int* selectedData = reinterpret_cast<int*>(selected.specialBuffer()); |         I* selectedIndicesData = reinterpret_cast<I*>(selectedIndices.specialBuffer()); | ||||||
|         int* selectedIndicesData = reinterpret_cast<int*>(selectedIndices.specialBuffer()); |  | ||||||
|         I* outputBuf = reinterpret_cast<I*>(output->specialBuffer()); |         I* outputBuf = reinterpret_cast<I*>(output->specialBuffer()); | ||||||
|         nonMaxSuppressionKernel<T, I><<<output->lengthOf(), 512, 1024, *stream>>>(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, numBoxes, outputBuf, output->specialShapeInfo(), T(threshold)); | 
 | ||||||
|         NDArray::registerSpecialUse({output}, {boxes, scales}); |         bool* shouldSelectD; | ||||||
| //        for (int i = 0; i < boxes->sizeAt(0); ++i) { |         auto err = cudaMalloc(&shouldSelectD, sizeof(bool)); | ||||||
| //            if (selected.size() >= output->lengthOf()) break; |         if (err) { | ||||||
| //            bool shouldSelect = true; |             throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot allocate memory for bool flag", err); | ||||||
| //            // Overlapping boxes are likely to have similar scores, |         } | ||||||
| //            // therefore we iterate through the selected boxes backwards. |         for (I i = 0; i < boxes->sizeAt(0); ++i) { | ||||||
| //            for (int j = numSelected - 1; j >= 0; --j) { |             bool shouldSelect = numSelected < output->lengthOf(); | ||||||
| //                if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold)) { |             if (shouldSelect) { | ||||||
| //                    shouldSelect = false; |                 err = cudaMemcpy(shouldSelectD, &shouldSelect, sizeof(bool), cudaMemcpyHostToDevice); | ||||||
| //                    break; |                 if (err) { | ||||||
| //                } |                     throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", err); | ||||||
| //            } |                 } | ||||||
| //            if (shouldSelect) { | 
 | ||||||
| //                selected.push_back(indices[i]); |                 shouldSelectKernel<T> <<< 128, 256, 1024, *stream >>> | ||||||
| //                selectedIndices[numSelected++] = i; |                                                            (boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD); | ||||||
| //            } |                 err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), cudaMemcpyDeviceToHost); | ||||||
| //        } |                 if (err) { | ||||||
| //        for (size_t e = 0; e < selected.size(); ++e) |                     throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", err); | ||||||
| //            output->p<int>(e, selected[e]); |                 } | ||||||
| // |             } | ||||||
|         delete indices; | 
 | ||||||
|  |             if (shouldSelect) { | ||||||
|  |                 cudaMemcpy(reinterpret_cast<I*>(output->specialBuffer()) + numSelected, indexBuf + i, sizeof(I), cudaMemcpyDeviceToDevice); | ||||||
|  |                 cudaMemcpy(selectedIndicesData + numSelected, &i, sizeof(I), cudaMemcpyHostToDevice); | ||||||
|  |                 numSelected++; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         err = cudaFree(shouldSelectD); | ||||||
|  |         if (err) { | ||||||
|  |             throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot deallocate memory for bool flag", err); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|     } |     } | ||||||
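The loop above is the GPU form of the usual greedy selection described in the removed CPU comments: boxes are visited in score order, and a candidate is kept only if no already-selected box suppresses it. A compact host-side sketch of that control flow (illustration only; the overlap predicate is left abstract) is:

#include <cstddef>
#include <functional>
#include <vector>

// sortedIndices: box indices ordered by score, best first (cf. sortByValue above)
// suppressedBy(kept, candidate): true when `candidate` overlaps `kept` too strongly
static std::vector<int> greedyNms(const std::vector<int>& sortedIndices,
                                  const std::function<bool(int, int)>& suppressedBy,
                                  size_t maxSize) {
    std::vector<int> selected;
    for (int idx : sortedIndices) {
        if (selected.size() >= maxSize)
            break;
        bool keep = true;
        // overlapping boxes tend to have similar scores, so walk the
        // already-selected boxes from the most recently added one backwards
        for (auto it = selected.rbegin(); keep && it != selected.rend(); ++it)
            keep = !suppressedBy(*it, idx);
        if (keep)
            selected.push_back(idx);
    }
    return selected;
}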
| 
 | 
 | ||||||
|     void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { |     void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { | ||||||
|  | |||||||
| @ -32,24 +32,24 @@ namespace nd4j { | |||||||
| namespace ops { | namespace ops { | ||||||
| namespace helpers { | namespace helpers { | ||||||
| 
 | 
 | ||||||
|     template <typename T> | //    template <typename T> | ||||||
|     static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { | //    static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { | ||||||
|         if (theFirst != theSecond) { | //        if (theFirst != theSecond) { | ||||||
|             auto start = threadIdx.x + blockIdx.x * blockDim.x; | //            auto start = threadIdx.x + blockIdx.x * blockDim.x; | ||||||
|             auto step = blockDim.x * gridDim.x; | //            auto step = blockDim.x * gridDim.x; | ||||||
|             for (auto i = start; i < N; i += step) { | //            for (auto i = start; i < N; i += step) { | ||||||
|                 Nd4jLong iCoord1[] = {theFirst, i}; | //                Nd4jLong iCoord1[] = {theFirst, i}; | ||||||
|                 Nd4jLong iCoord2[] = {theSecond, i}; | //                Nd4jLong iCoord2[] = {theSecond, i}; | ||||||
|                 auto iIndex1 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord1, 2); | //                auto iIndex1 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord1, 2); | ||||||
|                 auto iIndex2 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord2, 2); | //                auto iIndex2 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord2, 2); | ||||||
|                 //atomicExch(&matrix[iIndex1], matrix[iIndex2]); | //                //atomicExch(&matrix[iIndex1], matrix[iIndex2]); | ||||||
|                 T e0 = matrix[iIndex1]; | //                T e0 = matrix[iIndex1]; | ||||||
|                 T e1 = matrix[iIndex2]; | //                T e1 = matrix[iIndex2]; | ||||||
|                 matrix[iIndex1] = e0; | //                matrix[iIndex1] = e0; | ||||||
|                 matrix[iIndex2] = e1; | //                matrix[iIndex2] = e1; | ||||||
|             } | //            } | ||||||
|         } | //        } | ||||||
|     } | //    } | ||||||
| //    BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES); | //    BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES); | ||||||
| // | // | ||||||
| //    void swapRows(NDArray* matrix, int theFirst, int theSecond) { | //    void swapRows(NDArray* matrix, int theFirst, int theSecond) { | ||||||
| @ -71,9 +71,14 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
|         for (int i = start + 1; i < n; i += step) { |         for (int i = start + 1; i < n; i += step) { | ||||||
|             Nd4jLong pos[] = {i, i - 1}; |             Nd4jLong pos[] = {i, i - 1}; | ||||||
|  |             Nd4jLong posX[] = {i, i}; | ||||||
|  |             Nd4jLong posY[] = {i - 1, i - 1}; | ||||||
|             auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); |             auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); | ||||||
|  |             auto dxIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); | ||||||
|  |             auto dyIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2); | ||||||
|             auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); |             auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); | ||||||
|             inverted[zIndex] = -input[xIndex]; |             inverted[zIndex] = -input[xIndex] / (input[dxIndex] * input[dyIndex]); | ||||||
|  | //            math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex] / input[dIndex]); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -91,10 +96,11 @@ namespace helpers { | |||||||
|         auto start = threadIdx.x + blockIdx.x * blockDim.x; |         auto start = threadIdx.x + blockIdx.x * blockDim.x; | ||||||
|         auto step = blockDim.x * gridDim.x; |         auto step = blockDim.x * gridDim.x; | ||||||
| 
 | 
 | ||||||
|         for (int i = start + 1; i < n; i += step) { |         for (int i = start; i < n; i += step) { | ||||||
|             Nd4jLong pos[] = {i, i}; |             Nd4jLong pos[] = {i, i}; | ||||||
|             auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); |             auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); | ||||||
|             auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); |             auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); | ||||||
|  | //            math::atomics::nd4j_atomicDiv(&inverted[zIndex], input[xIndex]); | ||||||
|             inverted[zIndex] /= input[xIndex]; |             inverted[zIndex] /= input[xIndex]; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -113,16 +119,16 @@ namespace helpers { | |||||||
|         auto start = threadIdx.x + blockIdx.x * blockDim.x; |         auto start = threadIdx.x + blockIdx.x * blockDim.x; | ||||||
|         auto step = blockDim.x * gridDim.x; |         auto step = blockDim.x * gridDim.x; | ||||||
| 
 | 
 | ||||||
|         for (int i = start + 1; i < n - 1; i += step) { |         for (int i = start; i < n - 1; i += step) { | ||||||
|             Nd4jLong pos[] = {i, i + 1}; |             Nd4jLong pos[] = {i, i + 1}; | ||||||
|             Nd4jLong posY[] = {i, i}; |             //Nd4jLong posY[] = {i, i}; | ||||||
|             Nd4jLong posX[] = {i + 1, i + 1}; |             Nd4jLong posX[] = {i + 1, i + 1}; | ||||||
|             auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); |             auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); | ||||||
|             auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); | //            auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2); | ||||||
| //            auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); | //            auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); | ||||||
|             auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2); |             auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2); | ||||||
|             auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); |             auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); | ||||||
|             inverted[zIndex] -= input[xIndex] * inverted[iIndex] / input[yIndex]; |             math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex]); // / input[yIndex]); | ||||||
|             //inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i) |             //inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i) | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -142,16 +148,18 @@ namespace helpers { | |||||||
| //        auto step = blockDim.x * gridDim.x; | //        auto step = blockDim.x * gridDim.x; | ||||||
| 
 | 
 | ||||||
|         for (int i = blockIdx.x + 2; i < n; i += gridDim.x) { |         for (int i = blockIdx.x + 2; i < n; i += gridDim.x) { | ||||||
|             for (int j = i - 2; j > -1; --j) |             for (int j = i - 2; j >= 0; --j) | ||||||
|                 for (int k = threadIdx.x; k < i; k += blockDim.x) { |                 for (int k = threadIdx.x; k < i; k += blockDim.x) { | ||||||
|                     Nd4jLong posZ[] = {i, j}; |                     Nd4jLong posZ[] = {i, j}; | ||||||
|                     Nd4jLong posX[] = {k, j}; |                     Nd4jLong posY[] = {k, j}; | ||||||
|                     Nd4jLong posY[] = {i, k}; |                     Nd4jLong posX[] = {i, k}; | ||||||
|  |                     Nd4jLong posD[] = {i, i}; | ||||||
| 
 | 
 | ||||||
|                     auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); |                     auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); | ||||||
|                     auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); |                     auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); | ||||||
|  |                     auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2); | ||||||
|                     auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); |                     auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); | ||||||
|                     inverted[zIndex] -= inverted[yIndex] * input[xIndex]; |                     math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex] / input[dIndex]); | ||||||
|                 } |                 } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -176,13 +184,13 @@ namespace helpers { | |||||||
|                     Nd4jLong posZ[] = {i, j}; |                     Nd4jLong posZ[] = {i, j}; | ||||||
|                     Nd4jLong posY[] = {k, j}; |                     Nd4jLong posY[] = {k, j}; | ||||||
|                     Nd4jLong posX[] = {i, k}; |                     Nd4jLong posX[] = {i, k}; | ||||||
|                     Nd4jLong posD[] = {i, i}; | //                    Nd4jLong posD[] = {i, i}; | ||||||
| 
 | 
 | ||||||
|                     auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); |                     auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); | ||||||
|                     auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); |                     auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); | ||||||
|                     auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2); |   //                  auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2); | ||||||
|                     auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); |                     auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); | ||||||
|                     inverted[zIndex] -= inverted[yIndex] * input[xIndex] / input[dIndex]; |                     math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex]);// / input[dIndex]); | ||||||
|                 } |                 } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -196,14 +204,18 @@ namespace helpers { | |||||||
|         LaunchContext* context = inputMatrix->getContext(); |         LaunchContext* context = inputMatrix->getContext(); | ||||||
|         auto stream = context->getCudaStream(); |         auto stream = context->getCudaStream(); | ||||||
| 
 | 
 | ||||||
|  |         // invert main diagonal | ||||||
|  |         upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||||
|  |         // compute the first sub-diagonal (elements just below the main diagonal) | ||||||
|         invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); |         invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||||
|  | //        invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||||
|         invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); |         invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||||
|     } |     } | ||||||
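Taken together, upvertKernel, invertKernelLow and invertLowKernel evaluate the standard forward-substitution recurrence for the inverse X of a lower-triangular matrix L: X[i][i] = 1/L[i][i] and, below the diagonal, X[i][j] = -(1/L[i][i]) * sum over k = j..i-1 of L[i][k] * X[k][j]. A minimal single-threaded sketch of that same recurrence, meant only as a reading aid for the kernels, is:

#include <vector>

// forward-substitution inverse of a lower-triangular matrix L (n x n)
static std::vector<std::vector<double>> invertLowerReference(const std::vector<std::vector<double>>& L) {
    const size_t n = L.size();
    std::vector<std::vector<double>> X(n, std::vector<double>(n, 0.0));
    for (size_t i = 0; i < n; i++) {
        X[i][i] = 1.0 / L[i][i];                    // main diagonal (cf. upvertKernel)
        for (size_t j = 0; j < i; j++) {            // strictly below the diagonal
            double s = 0.0;
            for (size_t k = j; k < i; k++)          // accumulate L[i][k] * X[k][j]
                s += L[i][k] * X[k][j];
            X[i][j] = -s / L[i][i];                 // cf. invertKernelLow / invertLowKernel
        }
    }
    return X;
}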
| 
 | 
 | ||||||
|     BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); |     BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE); | ||||||
| 
 | 
 | ||||||
|     void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { |     void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { | ||||||
|         BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES); |         BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
| @ -215,58 +227,58 @@ namespace helpers { | |||||||
|             return; |             return; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); |         //upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||||
|         upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); |         upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||||
|         invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); |         invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); |     BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE); | ||||||
| 
 | 
 | ||||||
|     void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { |     void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { | ||||||
|         BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES); |         BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     template <typename T> | //    template <typename T> | ||||||
|     static __global__ void lupKernel(T* compound, Nd4jLong* compoundShape, T* permutation, Nd4jLong* permutationShape, Nd4jLong rowNum) { | //    static __global__ void lupKernel(T* compound, Nd4jLong* compoundShape, T* permutation, Nd4jLong* permutationShape, Nd4jLong rowNum) { | ||||||
|         int swapCount = 0; | //        int swapCount = 0; | ||||||
|         for(int i = blockIdx.x; i < rowNum; i += gridDim.x ) { | //        for(int i = blockIdx.x; i < rowNum; i += gridDim.x ) { | ||||||
|             auto pivotValue = T(0.0); | //            auto pivotValue = T(0.0); | ||||||
|             auto pivot = -1; | //            auto pivot = -1; | ||||||
| 
 | // | ||||||
|             for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) { | //            for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) { | ||||||
|                 Nd4jLong rowCoord[] = {rowCounter, i}; | //                Nd4jLong rowCoord[] = {rowCounter, i}; | ||||||
|                 auto rowPos = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), rowCoord, 2); | //                auto rowPos = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), rowCoord, 2); | ||||||
|                 if(nd4j::math::nd4j_abs(compound[rowPos]) > pivotValue ) { | //                if(nd4j::math::nd4j_abs(compound[rowPos]) > pivotValue ) { | ||||||
|                     pivotValue = nd4j::math::nd4j_abs(compound[rowPos]); | //                    pivotValue = nd4j::math::nd4j_abs(compound[rowPos]); | ||||||
|                     pivot = rowCounter; | //                    pivot = rowCounter; | ||||||
|                 } | //                } | ||||||
|             } | //            } | ||||||
| 
 | // | ||||||
|             if( pivotValue != T(0.0) ) { | //            if( pivotValue != T(0.0) ) { | ||||||
|                 swapRows_<T>(compound, compoundShape, pivot, i, rowNum); | //                swapRows_<T>(compound, compoundShape, pivot, i, rowNum); | ||||||
|                 swapRows_<T>(permutation, permutationShape, pivot, i, rowNum); | //                swapRows_<T>(permutation, permutationShape, pivot, i, rowNum); | ||||||
|                 if (pivot != i) | //                if (pivot != i) | ||||||
|                     swapCount++; | //                    swapCount++; | ||||||
| 
 | // | ||||||
|                 for( int j = i + 1; j < rowNum; j++ ) { | //                for( int j = i + 1; j < rowNum; j++ ) { | ||||||
|                     Nd4jLong posJIbuf[] = {j, i}; | //                    Nd4jLong posJIbuf[] = {j, i}; | ||||||
|                     Nd4jLong posIIbuf[] = {i, i}; | //                    Nd4jLong posIIbuf[] = {i, i}; | ||||||
|                     auto posJI = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJIbuf, 2); | //                    auto posJI = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJIbuf, 2); | ||||||
|                     auto posII = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIIbuf, 2); | //                    auto posII = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIIbuf, 2); | ||||||
| 
 | // | ||||||
|                     compound[posJI] /= compound[posII]; | //                    compound[posJI] /= compound[posII]; | ||||||
|                     for( int k = i + 1; k < rowNum; k++ ) { | //                    for( int k = i + 1; k < rowNum; k++ ) { | ||||||
|                         Nd4jLong posJKbuf[] = {j, k}; | //                        Nd4jLong posJKbuf[] = {j, k}; | ||||||
|                         Nd4jLong posIKbuf[] = {i, k}; | //                        Nd4jLong posIKbuf[] = {i, k}; | ||||||
|                         auto posJK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJKbuf, 2); | //                        auto posJK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJKbuf, 2); | ||||||
|                         auto posIK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIKbuf, 2); | //                        auto posIK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIKbuf, 2); | ||||||
|                         T arg = compound[posJI] * compound[posIK]; | //                        T arg = compound[posJI] * compound[posIK]; | ||||||
|                         compound[posJK] -= arg; | //                        compound[posJK] -= arg; | ||||||
|                     } | //                    } | ||||||
|                 } | //                } | ||||||
|             } | //            } | ||||||
|         } | //        } | ||||||
|     } | //    } | ||||||
| 
 | 
 | ||||||
|     template <typename T, typename F> |     template <typename T, typename F> | ||||||
|     static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) { |     static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) { | ||||||
| @ -332,6 +344,30 @@ namespace helpers { | |||||||
|             matrix[j] = (F)inputBuf[xIndex]; |             matrix[j] = (F)inputBuf[xIndex]; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     template <typename T, typename F> | ||||||
|  |     static __global__ void returnMatrix(void* output, Nd4jLong* outputShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) { | ||||||
|  |         __shared__ F* matrix; | ||||||
|  |         __shared__ T* outputBuf; | ||||||
|  |         __shared__ Nd4jLong outputLen; | ||||||
|  |         __shared__ Nd4jLong n2; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             matrix = reinterpret_cast<F*>(input); | ||||||
|  |             outputBuf = reinterpret_cast<T*>(output); | ||||||
|  |             outputLen = shape::length(outputShape); | ||||||
|  |             n2 = rowLen * rowLen; | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  |         auto start = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |         auto step = blockDim.x * gridDim.x; | ||||||
|  | 
 | ||||||
|  |         for (int k = pos + start, j = start; j < n2; k += step, j += step) { | ||||||
|  |             auto zIndex = shape::getIndexOffset(k, outputShape, outputLen); | ||||||
|  |             outputBuf[zIndex] = (T)matrix[j]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     template <typename F> |     template <typename F> | ||||||
|     static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) { |     static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) { | ||||||
|         __shared__ F* permutation; |         __shared__ F* permutation; | ||||||
| @ -462,7 +498,7 @@ namespace helpers { | |||||||
|                             d_work, |                             d_work, | ||||||
|                             permutationBuf, |                             permutationBuf, | ||||||
|                             d_info); |                             d_info); | ||||||
|                     fillUpPermutation<float><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); |                     fillUpPermutation<T><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); | ||||||
|                     permutation->tickWriteDevice(); |                     permutation->tickWriteDevice(); | ||||||
|                 } |                 } | ||||||
|                 err = cudaFree(d_work); |                 err = cudaFree(d_work); | ||||||
| @ -483,7 +519,7 @@ namespace helpers { | |||||||
| //        NDArray::registerSpecialUse({input}, {input}); | //        NDArray::registerSpecialUse({input}, {input}); | ||||||
|         input->tickWriteDevice(); |         input->tickWriteDevice(); | ||||||
|     } |     } | ||||||
|     BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); |     BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_NATIVE); | ||||||
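
For context, lup_ appears to wrap cuSOLVER's in-place getrf factorization (the d_work, permutation and d_info arguments above follow that call pattern). A minimal sketch of that API, offered as an assumption rather than the helper's exact code; handle creation, error checks and the surrounding NDArray plumbing are omitted, and names such as lupSketch and devA are illustrative:

    #include <cusolverDn.h>
    #include <cuda_runtime.h>

    // Factor an n x n column-major float matrix in place: devA ends up holding L and U,
    // dIpiv receives the pivot indices, dInfo reports singularity.
    void lupSketch(cusolverDnHandle_t handle, float* devA, int n) {
        int lwork = 0;
        cusolverDnSgetrf_bufferSize(handle, n, n, devA, n, &lwork);   // workspace size query

        float* dWork = nullptr;
        int*   dIpiv = nullptr;
        int*   dInfo = nullptr;
        cudaMalloc(&dWork, lwork * sizeof(float));
        cudaMalloc(&dIpiv, n * sizeof(int));
        cudaMalloc(&dInfo, sizeof(int));

        cusolverDnSgetrf(handle, n, n, devA, n, dWork, dIpiv, dInfo); // LU with partial pivoting

        cudaFree(dWork); cudaFree(dIpiv); cudaFree(dInfo);
    }
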
| 
 | 
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) { |     static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) { | ||||||
| @ -504,32 +540,32 @@ namespace helpers { | |||||||
|         output->assign(1.f); |         output->assign(1.f); | ||||||
|         for (int e = 0; e < output->lengthOf(); e++) { |         for (int e = 0; e < output->lengthOf(); e++) { | ||||||
|             Nd4jLong pos = e * n2; |             Nd4jLong pos = e * n2; | ||||||
|             if (matrix.dataType() == input->dataType()) | //            if (matrix.dataType() == input->dataType()) | ||||||
|                 fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); |                 fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||||
|             else | //            else | ||||||
|                 fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | //                fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||||
| 
 | 
 | ||||||
|             if (matrix.dataType() == input->dataType()) | //            if (matrix.dataType() == input->dataType()) | ||||||
|               lup_<T>(context, &matrix, nullptr, nullptr); |               lup_<T>(context, &matrix, nullptr, nullptr); | ||||||
|             else | //            else | ||||||
|                 lup_<float>(context, &matrix, nullptr, nullptr); | //                lup_<float>(context, &matrix, nullptr, nullptr); | ||||||
|             auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); |             auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); | ||||||
|             auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer()); |             auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer()); | ||||||
|             auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset; |             auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset; | ||||||
|             if (matrix.dataType() == input->dataType()) | //            if (matrix.dataType() == input->dataType()) | ||||||
|             determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); |             determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); | ||||||
|             else | //            else | ||||||
|                 determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); | //                determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); | ||||||
|         } |         } | ||||||
|         NDArray::registerSpecialUse({output}, {input}); |         NDArray::registerSpecialUse({output}, {input}); | ||||||
| 
 | 
 | ||||||
|         return Status::OK(); |         return Status::OK(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES); |     BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE); | ||||||
| 
 | 
 | ||||||
|     int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { |     int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { | ||||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES); |         BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE); | ||||||
|     } |     } | ||||||
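
Worked check for the path above: after lup_ the diagonal of the factored tile holds U's diagonal, and determinantKernel reduces the tile by multiplying those entries. For example, for A = [[4, 3], [6, 3]] an LU factorization gives U = [[4, 3], [0, -1.5]], so the product of the diagonal is 4 * (-1.5) = -6, matching det(A) = 4*3 - 3*6 = -6.
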
| 
 | 
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
| @ -552,22 +588,22 @@ namespace helpers { | |||||||
|         output->assign(1.f); |         output->assign(1.f); | ||||||
|         for (int e = 0; e < output->lengthOf(); e++) { |         for (int e = 0; e < output->lengthOf(); e++) { | ||||||
|             Nd4jLong pos = e * n2; |             Nd4jLong pos = e * n2; | ||||||
|             if (matrix.dataType() == input->dataType()) | //            if (matrix.dataType() == input->dataType()) | ||||||
|             fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); |             fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||||
|             else | //            else | ||||||
|                 fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | //                fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); | ||||||
| 
 | 
 | ||||||
|             if (matrix.dataType() == input->dataType()) | //            if (matrix.dataType() == input->dataType()) | ||||||
|                 lup_<T>(context, &matrix, nullptr, nullptr); |                 lup_<T>(context, &matrix, nullptr, nullptr); | ||||||
|             else | //            else | ||||||
|                 lup_<float>(context, &matrix, nullptr, nullptr); | //                lup_<float>(context, &matrix, nullptr, nullptr); | ||||||
|             auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); |             auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); | ||||||
|             auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer()); |             auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer()); | ||||||
|             auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset; |             auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset; | ||||||
|             if (matrix.dataType() == input->dataType()) | //            if (matrix.dataType() == input->dataType()) | ||||||
|                 determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); |                 determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); | ||||||
|             else | //            else | ||||||
|                 determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); | //                determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); | ||||||
|         } |         } | ||||||
|         NDArray::registerSpecialUse({output}, {input}); |         NDArray::registerSpecialUse({output}, {input}); | ||||||
| 
 | 
 | ||||||
| @ -576,10 +612,10 @@ namespace helpers { | |||||||
|         return ND4J_STATUS_OK; |         return ND4J_STATUS_OK; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES); |     BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE); | ||||||
| 
 | 
 | ||||||
|     int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { |     int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { | ||||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES); |         BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
| @ -597,10 +633,12 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
|         if (threadIdx.x == 0) { |         if (threadIdx.x == 0) { | ||||||
|             xShapeOf = shape::shapeOf(lowerShape); |             xShapeOf = shape::shapeOf(lowerShape); | ||||||
|             yShapeOf = shape::shapeOf(upperShape); |  | ||||||
|             zShapeOf = shape::shapeOf(matrixShape); |  | ||||||
|             xStrideOf = shape::stride(lowerShape); |             xStrideOf = shape::stride(lowerShape); | ||||||
|  | 
 | ||||||
|  |             yShapeOf = shape::shapeOf(upperShape); | ||||||
|             yStrideOf = shape::stride(upperShape); |             yStrideOf = shape::stride(upperShape); | ||||||
|  | 
 | ||||||
|  |             zShapeOf = shape::shapeOf(matrixShape); | ||||||
|             zStrideOf = shape::stride(matrixShape); |             zStrideOf = shape::stride(matrixShape); | ||||||
|             lowerMatrix = reinterpret_cast<T*>(lowerBuf); |             lowerMatrix = reinterpret_cast<T*>(lowerBuf); | ||||||
|             upperMatrix = reinterpret_cast<T*>(upperBuf); |             upperMatrix = reinterpret_cast<T*>(upperBuf); | ||||||
| @ -610,15 +648,16 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
|         for (int k = blockIdx.x; k < n; k += gridDim.x) {  // scatter the LU compound into separate lower- and upper-triangular matrices |         for (int k = blockIdx.x; k < n; k += gridDim.x) {  // scatter the LU compound into separate lower- and upper-triangular matrices | ||||||
|             for (int j = threadIdx.x; j < n; j += blockDim.x) { |             for (int j = threadIdx.x; j < n; j += blockDim.x) { | ||||||
|                 Nd4jLong posX[] = {j, k}; |                 Nd4jLong posX[] = {k, j}; | ||||||
| 
 |                 Nd4jLong posD[] = {j, j}; | ||||||
|                 auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2); |                 auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2); | ||||||
|                 auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2); |                 auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2); | ||||||
|                 auto pos =  shape::getOffset(0, zShapeOf, zStrideOf, posX, 2); |                 auto iPos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2); | ||||||
|                 if (k <= j) |                 auto dPos = shape::getOffset(0, zShapeOf, zStrideOf, posD, 2); | ||||||
|                     lowerMatrix[xPos] = matrix[pos];//(k, j); |                 if (k >= j) | ||||||
|  |                     lowerMatrix[xPos] = matrix[iPos];//(k, j); | ||||||
|                 else |                 else | ||||||
|                     upperMatrix[yPos] = matrix[pos]; //k, j); |                     upperMatrix[yPos] = matrix[iPos]; //k, j); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -639,38 +678,26 @@ namespace helpers { | |||||||
|         auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1}); |         auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1}); | ||||||
|         auto stream = context->getCudaStream(); |         auto stream = context->getCudaStream(); | ||||||
| 
 | 
 | ||||||
|  | //        PRAGMA_OMP_PARALLEL_FOR | ||||||
|         for (auto i = 0LL; i < packX.numberOfTads(); i++) { |         for (auto i = 0LL; i < packX.numberOfTads(); i++) { | ||||||
|             fillMatrix<T, float><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); |             fillMatrix<T, T><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); | ||||||
|             permutation.assign(0.f); |  | ||||||
|             lup_<float>(context, &matrix, &compound, &permutation); |  | ||||||
|             matrix.tickWriteDevice(); |             matrix.tickWriteDevice(); | ||||||
|             permutation.tickWriteDevice(); |             compound.assign(matrix); | ||||||
|             permutation.printIndexedBuffer("PERMUTE"); |             lup_<T>(context, &compound, nullptr, nullptr); | ||||||
|             lower.setIdentity(); // set up U to identity matrix |             fillLowerUpperKernel<T><<<n, n, 128, *stream>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n); | ||||||
|             upper.setIdentity(); |             matrix.assign(0); | ||||||
|             fillLowerUpperKernel<float><<<1, n2, 128>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n); |             invertUpperMatrix(&upper, &matrix); // U^{-1} | ||||||
|             lower.tickWriteDevice(); |             compound.assign(0); | ||||||
|             upper.tickWriteDevice(); |             invertLowerMatrix(&lower, &compound); // L^{-1} | ||||||
|             invertUpperMatrix(&upper, &matrix); |  | ||||||
|             invertLowerMatrix(&lower, &upper); |  | ||||||
|             lower.tickWriteDevice(); |  | ||||||
|             upper.tickWriteDevice(); |  | ||||||
|             lower.printIndexedBuffer("LOWER"); |  | ||||||
|             upper.printIndexedBuffer("UPPER"); |  | ||||||
| 
 | 
 | ||||||
|             nd4j::MmulHelper::mmul(&matrix, &upper, &compound, 1.0, 0.0); |             nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0); | ||||||
|             nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0); |             returnMatrix<T, T><<<1, n2, 128, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n); | ||||||
| //            for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { |  | ||||||
| //                output->t<T>(k) = matrix.template t<T>(row++); |  | ||||||
| //            } |  | ||||||
|         } |         } | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         return Status::OK(); |         return Status::OK(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { |     int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { | ||||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES); |         BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE); | ||||||
|     } |     } | ||||||
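
In the loop above, each n x n tile is factored as L*U by lup_, the two triangles are inverted separately, and the mmul forms U^{-1} * L^{-1}, which equals the tile's inverse; returnMatrix then writes it back into the batched output. A hypothetical host-side call, assuming NDArrayFactory and the default LaunchContext (the 2x2 values are illustrative only):

    auto in  = NDArrayFactory::create<float>('c', {2, 2}, {4.f, 3.f, 6.f, 3.f});
    auto out = NDArrayFactory::create<float>('c', {2, 2});
    nd4j::ops::helpers::inverse(nd4j::LaunchContext::defaultContext(), &in, &out);
    // det(in) = 4*3 - 3*6 = -6, so out should be approximately {-0.5, 0.5, 1.0, -0.6667}
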
| 
 | 
 | ||||||
|     bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) { |     bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) { | ||||||
| @ -803,7 +830,7 @@ namespace helpers { | |||||||
|         return cholesky_(context, input, output, inplace); |         return cholesky_(context, input, output, inplace); | ||||||
|     } |     } | ||||||
| //    BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); | //    BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); | ||||||
|     BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES); |     BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE); | ||||||
| 
 | 
 | ||||||
|     __global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) { |     __global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) { | ||||||
|         __shared__ double* output; |         __shared__ double* output; | ||||||
|  | |||||||
| @ -143,7 +143,7 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
|     /////////////////////////////////////////////////////////////////// |     /////////////////////////////////////////////////////////////////// | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     static void _reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){ |     static void reverseSequence_(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){ | ||||||
|         int posOfNonUnityDim = -1; |         int posOfNonUnityDim = -1; | ||||||
|         seqLengths->syncToHost(); |         seqLengths->syncToHost(); | ||||||
|         auto stream = context->getCudaStream(); |         auto stream = context->getCudaStream(); | ||||||
| @ -193,7 +193,7 @@ namespace helpers { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { |     void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { | ||||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), _reverseSequence, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES); |         BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     ////////////////////////////////////////////////////////////////////////// |     ////////////////////////////////////////////////////////////////////////// | ||||||
|  | |||||||
| @ -391,7 +391,7 @@ static void scatterCudaLauncher(const int blocksPerGrid, const int threadsPerBlo | |||||||
| /////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////// | ||||||
| void scatter(nd4j::LaunchContext  *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { | void scatter(nd4j::LaunchContext  *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { | ||||||
| 
 | 
 | ||||||
|     PointersManager manager(context, "scatterND"); |     PointersManager manager(context, "scatter"); | ||||||
| 
 | 
 | ||||||
|     NDArray::prepareSpecialUse({&output}, {&updates, &indices}); |     NDArray::prepareSpecialUse({&output}, {&updates, &indices}); | ||||||
| 
 | 
 | ||||||
|  | |||||||
										
											
File diff suppressed because it is too large (Load Diff)

427  libnd4j/include/ops/declarable/helpers/cuda/segment_max.cu  (Normal file, new)
							| @ -0,0 +1,427 @@ | |||||||
|  | /******************************************************************************* | ||||||
|  |  * Copyright (c) 2015-2018 Skymind, Inc. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0. | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | // | ||||||
|  | //  @author GS <sgazeos@gmail.com> | ||||||
|  | // | ||||||
|  | 
 | ||||||
|  | #include <ops/declarable/helpers/segment.h> | ||||||
|  | #include <ops/declarable/helpers/segment_common.h> | ||||||
|  | 
 | ||||||
|  | #include <NDArrayFactory.h> | ||||||
|  | #include <helpers/ShapeUtils.h> | ||||||
|  | #include <helpers/TAD.h> | ||||||
|  | #include <exceptions/cuda_exception.h> | ||||||
|  | #include <PointersManager.h> | ||||||
|  | #include <ConstantTadHelper.h> | ||||||
|  | 
 | ||||||
|  | namespace nd4j { | ||||||
|  |     namespace ops { | ||||||
|  |         namespace helpers { | ||||||
|  | 
 | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             // Segment ops linear kernels | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |             template<typename T, typename I> | ||||||
|  |             static __global__ void | ||||||
|  |             segmentMaxLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, | ||||||
|  |                                    void *output, Nd4jLong *outputShape) { | ||||||
|  |                 __shared__ T *val; | ||||||
|  |                 __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |                 __shared__ T *x; | ||||||
|  |                 __shared__ T *z; | ||||||
|  |                 __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |                 if (threadIdx.x == 0) { | ||||||
|  |                     threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; | ||||||
|  |                     segment = blockIdx.x / threadsPerSegment; | ||||||
|  |                     x = reinterpret_cast<T *>(input); | ||||||
|  |                     z = reinterpret_cast<T *>(output); | ||||||
|  |                     extern __shared__ unsigned char shmem[]; | ||||||
|  |                     val = reinterpret_cast<T *>(shmem); | ||||||
|  |                     xLen = shape::length(inputShape); | ||||||
|  |                     zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  |                     if (segment < numOfClasses) { | ||||||
|  |                         zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |                         start = starts[segment]; | ||||||
|  |                         finish = start + lengths[segment]; | ||||||
|  |                         z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)]; | ||||||
|  |                         val[segment] = z[zIndex]; | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                 } | ||||||
|  |                 __syncthreads(); | ||||||
|  | 
 | ||||||
|  |                 for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |                     nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
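
For reference, segment max over data [1, 3, 2, 5, 4] with sorted segment ids [0, 0, 1, 1, 1] yields [3, 5]: thread 0 of the block handling a segment seeds the output with the segment's first element, and the remaining elements are folded in through nd4j_atomicMax.
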
|  | 
 | ||||||
|  |             template<typename T, typename I> | ||||||
|  |             static __global__ void | ||||||
|  |             unsortedSegmentMaxLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape, | ||||||
|  |                                            int *starts, int *lengths, Nd4jLong numOfClasses, void *output, | ||||||
|  |                                            Nd4jLong *outputShape) { | ||||||
|  |                 __shared__ T *val; | ||||||
|  |                 __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |                 __shared__ T *x; | ||||||
|  |                 __shared__ T *z; | ||||||
|  |                 __shared__ I *y; //int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |                 if (threadIdx.x == 0) { | ||||||
|  |                     segment = blockIdx.x; | ||||||
|  |                     x = reinterpret_cast<T *>(input); | ||||||
|  |                     z = reinterpret_cast<T *>(output); | ||||||
|  |                     y = reinterpret_cast<I *>(indices); | ||||||
|  |                     xLen = shape::length(inputShape); | ||||||
|  |                     zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  |                     zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |                     //start = starts[segment]; | ||||||
|  |                     //finish = start + lengths[segment]; | ||||||
|  |                     if (lengths[segment] > 0) | ||||||
|  |                         z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)]; | ||||||
|  |                     else | ||||||
|  |                         z[zIndex] = -DataTypeUtils::max<T>(); | ||||||
|  |                 } | ||||||
|  |                 __syncthreads(); | ||||||
|  |                 if (lengths[segment] > 0) | ||||||
|  |                     for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { | ||||||
|  |                         auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |                         auto yIndex = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |                         if (y[yIndex] == segment) { | ||||||
|  |                             nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             template <typename T, typename I> | ||||||
|  |             static __global__ void segmentMaxTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, | ||||||
|  |                                                        Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, | ||||||
|  |                                                        Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets, T filler = 0) { | ||||||
|  | 
 | ||||||
|  |                 __shared__ T* val; | ||||||
|  |                 __shared__ Nd4jLong len, segment, zIndex, total; | ||||||
|  |                 __shared__ T* z; | ||||||
|  |                 __shared__ int start, finish; | ||||||
|  | 
 | ||||||
|  |                 if (threadIdx.x == 0) { | ||||||
|  |                     segment = indices[blockIdx.x]; // / threadsPerSegment; | ||||||
|  |                     z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment]; | ||||||
|  |                     len = shape::length(inputTads); | ||||||
|  | 
 | ||||||
|  |                     start = starts[segment]; | ||||||
|  |                     finish = start + lengths[segment]; | ||||||
|  |                     total = shape::sizeAt(inputShape, 0); | ||||||
|  |                 } | ||||||
|  |                 __syncthreads(); | ||||||
|  | 
 | ||||||
|  |                 auto idx = blockIdx.x; | ||||||
|  |                 if (blockIdx.x < total) { | ||||||
|  |                     auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx]; | ||||||
|  |                     if (blockIdx.x == start) { | ||||||
|  |                         for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                             auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                             auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                             z[zIndex] = x[xIndex]; | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                     else { | ||||||
|  |                         for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                             auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                             auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                             nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]); | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |             template <typename T, typename I> | ||||||
|  |             static void segmentMaxFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |                 //int numClasses = output->sizeAt(0); | ||||||
|  |                 // if input is a vector: (as if in doc sample) | ||||||
|  |                 //Nd4jLong idx = indices->e<Nd4jLong>(0); | ||||||
|  |                 auto stream = context->getCudaStream(); | ||||||
|  |                 indices->syncToHost(); | ||||||
|  |                 Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; | ||||||
|  |                 NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  |                 NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  | 
 | ||||||
|  |                 classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |                 classesRangesLens.assign(0); | ||||||
|  |                 dim3 dims(256, 512, 256); | ||||||
|  |                 int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |                 int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  |                 fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); | ||||||
|  | 
 | ||||||
|  |                 NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); | ||||||
|  | 
 | ||||||
|  |                 if (input->isVector()) { | ||||||
|  | 
 | ||||||
|  |                     segmentMaxLinearKernel<T,I><<<numOfClasses, input->lengthOf(), numOfClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |                 } | ||||||
|  |                 else { | ||||||
|  |                     std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |                     auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |                     auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |                     Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |                     Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |                     segmentMaxTadKernel<T,I><<<packX.numberOfTads(), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |                 } | ||||||
|  |                 NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |                 BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |             } | ||||||
|  |             BUILD_DOUBLE_TEMPLATE(template void segmentMaxFunctor_, (LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
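
A rough host-side illustration of the sorted forward helper (not part of the patch; the shapes, values and default LaunchContext are assumptions):

    auto data   = NDArrayFactory::create<float>('c', {5}, {1.f, 3.f, 2.f, 5.f, 4.f});
    auto ids    = NDArrayFactory::create<int>('c', {5}, {0, 0, 1, 1, 1});
    auto result = NDArrayFactory::create<float>('c', {2});
    nd4j::ops::helpers::segmentMaxFunctor(nd4j::LaunchContext::defaultContext(), &data, &ids, &result);
    // result should hold {3, 5}
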
|  | 
 | ||||||
|  |             template <typename T, typename I> | ||||||
|  |             static void unsortedSegmentMaxFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |                 auto stream = context->getCudaStream(); | ||||||
|  | //        NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); | ||||||
|  |                 NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  |                 NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  | //        NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); | ||||||
|  | //        classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); | ||||||
|  |                 classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |                 classesRangesLens.assign(0); | ||||||
|  |                 dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); | ||||||
|  | //        int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer()); | ||||||
|  |                 fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |                 int* begins = reinterpret_cast<int*>(classesRangesBegs.getSpecialBuffer()); | ||||||
|  |                 int* lengths = reinterpret_cast<int*>(classesRangesLens.getSpecialBuffer()); | ||||||
|  | 
 | ||||||
|  |                 if (input->isVector()) { | ||||||
|  |                     unsortedSegmentMaxLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |                 } | ||||||
|  |                 else { | ||||||
|  |                     std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |                     auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |                     auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |                     Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |                     Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |                     dims.x = input->sizeAt(0); | ||||||
|  |                     output->assign(-DataTypeUtils::max<T>()); | ||||||
|  |                     segmentMaxTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |                 BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             // segment max | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             template <typename T, typename I> | ||||||
|  |             static __global__ void segmentMaxBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput, | ||||||
|  |                                                             Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, | ||||||
|  |                                                             void* outputBuf, Nd4jLong* outputShape) { | ||||||
|  |                 __shared__ T* x; | ||||||
|  |                 __shared__ T* gradIn; | ||||||
|  |                 __shared__ T* gradOut; | ||||||
|  |                 __shared__ I* y; | ||||||
|  |                 __shared__ T* z; | ||||||
|  |                 __shared__ Nd4jLong xLen, gradLen; | ||||||
|  | 
 | ||||||
|  |                 if (threadIdx.x == 0) { | ||||||
|  |                     xLen = shape::length(inputShape); | ||||||
|  |                     x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |                     y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |                     z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |                     gradIn = reinterpret_cast<T*>(forwardOutput); | ||||||
|  |                     gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |                     gradLen = shape::length(epsShape); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 auto start = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |                 auto step = gridDim.x * blockDim.x; | ||||||
|  | 
 | ||||||
|  |                 for (auto e = start; e < xLen; e += step) { | ||||||
|  | 
 | ||||||
|  |                     auto zOffset = shape::getIndexOffset(e, outputShape, xLen); | ||||||
|  |                     auto xOffset = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |                     auto yOffset = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |                     auto classIndex = y[yOffset]; | ||||||
|  |                     auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen); | ||||||
|  |                     auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen); | ||||||
|  | 
 | ||||||
|  |                     if (nd4j::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { | ||||||
|  |                         z[zOffset] = gradOut[gradOffsetO]; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
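
The backward kernels route each upstream gradient only to the input positions whose value matches the forward maximum (within the 1e-6 tolerance). Continuing the example above, with upstream gradients [10, 20] for the two segments and a zero-initialized output, the propagated result is [0, 10, 0, 20, 0].
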
|  |             template <typename T, typename I> | ||||||
|  |             static __global__ void segmentMaxBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput, | ||||||
|  |                                                          Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, | ||||||
|  |                                                          void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad, | ||||||
|  |                                                          Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets, | ||||||
|  |                                                          Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, | ||||||
|  |                                                          Nd4jLong* outOffsets) { | ||||||
|  |                 __shared__ T* x; | ||||||
|  |                 __shared__ T* gradIn; | ||||||
|  |                 __shared__ T* gradOut; | ||||||
|  |                 __shared__ I* y; | ||||||
|  |                 __shared__ T* z; | ||||||
|  |                 __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; | ||||||
|  | 
 | ||||||
|  |                 if (threadIdx.x == 0) { | ||||||
|  |                     xLen = shape::length(inputShape); | ||||||
|  |                     x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |                     y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |                     z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |                     yLen = shape::length(indicesShape); | ||||||
|  |                     gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |                     gradIn = reinterpret_cast<T*>(forwardOutput); | ||||||
|  |                     gradLen = shape::length(epsShape); | ||||||
|  |                     currentLen = shape::length(outTad); | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { | ||||||
|  |                     auto yIndex = shape::getIndexOffset(i, indicesShape, yLen); | ||||||
|  |                     auto segment = y[yIndex]; | ||||||
|  |                     T* current = x + inputOffsets[i]; | ||||||
|  |                     T* currentOut = z + outOffsets[i]; | ||||||
|  |                     T* in = gradIn + gradInOffsets[segment]; | ||||||
|  |                     T* outGrad = gradOut + gradOutOffsets[segment]; | ||||||
|  | 
 | ||||||
|  |                     for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { | ||||||
|  |                         if (nd4j::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) | ||||||
|  |                             currentOut[e] = outGrad[e]; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             template <typename T, typename I> | ||||||
|  |             int segmentMaxFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |                 //int numOfClasses = gradOut->sizeAt(0); | ||||||
|  |                 // if input is a vector: (as if in doc sample) | ||||||
|  |                 auto stream = context->getCudaStream(); | ||||||
|  |                 NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context); | ||||||
|  |                 segmentMaxFunctor_<T, I>(context, input, indices, &tempRes); | ||||||
|  |                 NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); | ||||||
|  |                 if (input->isVector()) { | ||||||
|  |                     Nd4jLong loop_size = input->lengthOf(); | ||||||
|  |                     auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |                     segmentMaxBPLinearKernel<T,I><<<1 + gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                             tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                             indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |                 } | ||||||
|  |                 else { | ||||||
|  |                     std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |                     auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |                     auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |                     auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions); | ||||||
|  |                     auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |                     Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |                     Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |                     Nd4jLong* gradInTads = packGradIn.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets(); | ||||||
|  |                     Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |                     segmentMaxBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                             tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                             indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), | ||||||
|  |                             inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                             outputTads, outputTadOffsets); | ||||||
|  |                 } | ||||||
|  |                 NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); | ||||||
|  |                 return Status::OK(); | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |                 BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input, | ||||||
|  |                         indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             BUILD_DOUBLE_TEMPLATE(template int segmentMaxFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  | 
 | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             template <typename T, typename I> | ||||||
|  |             static int unsortedSegmentMaxFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |                 //int numOfClasses = gradOut->sizeAt(0); | ||||||
|  |                 // if input is a vector: (as if in doc sample) | ||||||
|  |                 auto stream = context->getCudaStream(); | ||||||
|  |                 NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context); | ||||||
|  |                 unsortedSegmentMaxFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes); | ||||||
|  |                 NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); | ||||||
|  |                 if (input->isVector()) { | ||||||
|  |                     Nd4jLong loop_size = input->lengthOf(); | ||||||
|  |                     auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |                     segmentMaxBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                             tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                             indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |                 } | ||||||
|  |                 else { | ||||||
|  |                     std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |                     auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |                     auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |                     auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions); | ||||||
|  |                     auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |                     Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |                     Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |                     Nd4jLong* gradInTads = packGradIn.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets(); | ||||||
|  |                     Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |                     Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |                     segmentMaxBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                             tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                             indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), | ||||||
|  |                             inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                             outputTads, outputTadOffsets); | ||||||
|  |                 } | ||||||
|  |                 NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); | ||||||
|  |                 return Status::OK(); | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |                 BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |             } | ||||||
|  |             // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |             BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
							
								
								
									
414  libnd4j/include/ops/declarable/helpers/cuda/segment_mean.cu  (Normal file, new)
							| @ -0,0 +1,414 @@ | |||||||
|  | /******************************************************************************* | ||||||
|  |  * Copyright (c) 2015-2018 Skymind, Inc. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0. | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | // | ||||||
|  | //  @author GS <sgazeos@gmail.com> | ||||||
|  | // | ||||||
|  | 
 | ||||||
|  | #include <ops/declarable/helpers/segment.h> | ||||||
|  | #include <ops/declarable/helpers/segment_common.h> | ||||||
|  | #include <NDArrayFactory.h> | ||||||
|  | #include <helpers/ShapeUtils.h> | ||||||
|  | #include <helpers/TAD.h> | ||||||
|  | #include <exceptions/cuda_exception.h> | ||||||
|  | #include <PointersManager.h> | ||||||
|  | #include <ConstantTadHelper.h> | ||||||
|  | 
 | ||||||
|  | namespace nd4j { | ||||||
|  | namespace ops { | ||||||
|  | namespace helpers { | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // Segment ops linear kernels | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentMeanLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; | ||||||
|  |             segment = blockIdx.x / threadsPerSegment; | ||||||
|  |             x = reinterpret_cast<T*>(input); | ||||||
|  |             z = reinterpret_cast<T*>(output); | ||||||
|  | //            extern __shared__ unsigned char shmem[]; | ||||||
|  | //            val = reinterpret_cast<T*>(shmem); | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  |             //[zIndex] = | ||||||
|  |             if (segment < numOfClasses) { | ||||||
|  |                 zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |                 start = starts[segment]; | ||||||
|  |                 finish = start + lengths[segment]; | ||||||
|  |                 //val[segment] = ; | ||||||
|  |                 z[zIndex] = T(x[shape::getIndexOffset(start, inputShape, xLen)] / lengths[segment]); | ||||||
|  | //                val[segment] = z[zIndex]; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { | ||||||
|  |             auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |             if (lengths[segment]) | ||||||
|  |                 nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex] / lengths[segment])); | ||||||
|  |         } | ||||||
|  |     } | ||||||
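For orientation: the linear kernel above covers the 1-D (vector) case of sorted segment mean. Block `segment` seeds its output slot with the segment's first element divided by the segment length, and the remaining threads atomically add the other elements, each pre-divided by the length. A minimal CPU sketch of the same computation (plain C++, illustration only, not part of the patch):

    #include <cstdio>
    #include <vector>

    // Reference: sorted segment mean over a 1-D input; ids are assumed sorted and 0-based.
    std::vector<double> segmentMeanRef(const std::vector<double>& x, const std::vector<int>& ids) {
        int numClasses = ids.empty() ? 0 : ids.back() + 1;
        std::vector<double> sum(numClasses, 0.0);
        std::vector<int> count(numClasses, 0);
        for (size_t i = 0; i < x.size(); ++i) {
            sum[ids[i]] += x[i];
            count[ids[i]] += 1;
        }
        for (int s = 0; s < numClasses; ++s)
            if (count[s] > 0) sum[s] /= count[s];
        return sum;
    }

    int main() {
        // segment 0 = {1,2,3}, segment 1 = {4,5}  ->  means {2.0, 4.5}
        for (auto v : segmentMeanRef({1, 2, 3, 4, 5}, {0, 0, 0, 1, 1}))
            printf("%f\n", v);
        return 0;
    }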
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void unsortedSegmentMeanLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ I* y; //int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  | //            threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; | ||||||
|  |             segment = blockIdx.x;// / threadsPerSegment; | ||||||
|  |             x = reinterpret_cast<T*>(input); | ||||||
|  |             z = reinterpret_cast<T*>(output); | ||||||
|  |             y = reinterpret_cast<I*>(indices); | ||||||
|  | //            extern __shared__ unsigned char shmem[]; | ||||||
|  | //            val = reinterpret_cast<T*>(shmem); | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  | //            if (segment < numOfClasses) { | ||||||
|  |             zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |             //start = starts[segment]; | ||||||
|  |             //finish = start + lengths[segment]; | ||||||
|  |             if (lengths[segment] > 0) | ||||||
|  |                 z[zIndex] = T(x[shape::getIndexOffset(starts[segment], inputShape, xLen)] / T(lengths[segment])); | ||||||
|  |             else | ||||||
|  |                 z[zIndex] = 0; //DataTypeUtils::max<T>(); | ||||||
|  | //                val[segment] = z[zIndex]; | ||||||
|  | //            } | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  |         if (lengths[segment] > 0) | ||||||
|  |             for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { | ||||||
|  |                 auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |                 auto yIndex = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |                 if (y[yIndex] == segment && e != starts[segment]) { | ||||||
|  |                     nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/T(lengths[segment]))); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // SegmentMean kernel | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentMeanTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong len, segment, zIndex, total; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             segment = indices[blockIdx.x]; // / threadsPerSegment; | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment]; | ||||||
|  |             len = shape::length(inputTads); | ||||||
|  |             start = starts[segment]; | ||||||
|  |             finish = start + lengths[segment]; | ||||||
|  |             total = shape::sizeAt(inputShape, 0); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         auto idx = blockIdx.x; | ||||||
|  |         if (blockIdx.x <= total) { | ||||||
|  |             auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx]; | ||||||
|  |             if (blockIdx.x == start) { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     z[zIndex] = T(x[xIndex]/lengths[segment]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             else { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     if (lengths[segment]) | ||||||
|  |                         nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/lengths[segment])); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
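In the non-vector case the TAD kernel above assigns one block per row of the input (a TAD along dimension 0): the row whose index equals `starts[segment]` writes `x/len` into the segment's output row, and every other row of the same segment atomically adds its own `x/len`. A rough CPU analogue of that row-wise scheme (illustrative sketch only; `out` is assumed pre-sized to one row per class and zero-initialized):

    #include <vector>

    // Row-wise segment mean: ids[r] selects the output row for input row r.
    void segmentMeanRowsRef(const std::vector<std::vector<double>>& rows,
                            const std::vector<int>& ids,
                            std::vector<std::vector<double>>& out) {
        std::vector<int> lengths(out.size(), 0);
        for (int id : ids) lengths[id] += 1;
        for (size_t r = 0; r < rows.size(); ++r) {
            const int seg = ids[r];
            for (size_t e = 0; e < rows[r].size(); ++e)
                out[seg][e] += rows[r][e] / lengths[seg]; // "first" row assigns, later rows accumulate
        }
    }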
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // segment mean | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  | 
 | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  | 
 | ||||||
|  |         dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  |         fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             segmentMeanLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             segmentMeanTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template void segmentMeanFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
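A sketch of how the forward helper might be driven from host code. The call below assumes an NDArrayFactory::create overload that accepts an initial data vector (such an overload is not shown in this patch, so treat it as an assumption) and reuses the default launch context:

    #include <NDArrayFactory.h>
    #include <ops/declarable/helpers/segment.h>

    void segmentMeanExample() {
        auto input   = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); // assumed overload
        auto indices = NDArrayFactory::create<int>('c', {5}, {0, 0, 0, 1, 1});             // sorted segment ids
        auto output  = NDArrayFactory::create<float>('c', {2});                            // one slot per segment

        nd4j::ops::helpers::segmentMeanFunctor(nd4j::LaunchContext::defaultContext(), &input, &indices, &output);
        // expected result: {2.f, 4.5f}
    }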
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  | //        NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  | //        NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); | ||||||
|  | //        classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  |         dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); | ||||||
|  | //        int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer()); | ||||||
|  |         fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             unsortedSegmentMeanLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             output->assign(0); | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             dims.x = input->sizeAt(0); | ||||||
|  |             segmentMeanTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output), | ||||||
|  |                               FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMeanFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
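The unsorted variant differs from the sorted one in two ways visible above: the number of classes is passed in explicitly (indices need not be sorted or contiguous), and a class that receives no elements is left at 0 for the mean. A small CPU reference for that behaviour (illustration only):

    #include <vector>

    // Unsorted segment mean: ids may appear in any order; empty classes yield 0.
    std::vector<double> unsortedSegmentMeanRef(const std::vector<double>& x,
                                               const std::vector<int>& ids,
                                               int numOfClasses) {
        std::vector<double> sum(numOfClasses, 0.0);
        std::vector<int> count(numOfClasses, 0);
        for (size_t i = 0; i < x.size(); ++i) {
            sum[ids[i]] += x[i];
            count[ids[i]] += 1;
        }
        for (int s = 0; s < numOfClasses; ++s)
            sum[s] = count[s] ? sum[s] / count[s] : 0.0;
        return sum;
    }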
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentMeanBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, | ||||||
|  |                                                      int* lengths, void* outputBuf, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradIn; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, gradLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); // make the shared pointers and lengths visible to all threads before use | ||||||
|  | 
 | ||||||
|  |         auto start = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |         auto step = gridDim.x * blockDim.x; | ||||||
|  | 
 | ||||||
|  |         for (auto e = start; e < xLen; e += step) { | ||||||
|  | 
 | ||||||
|  |             auto zOffset = shape::getIndexOffset(e, outputShape, xLen); | ||||||
|  |             auto xOffset = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |             auto yOffset = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |             auto classIndex = y[yOffset]; | ||||||
|  |             auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen); | ||||||
|  | 
 | ||||||
|  |             z[zOffset] = T(gradOut[gradOffsetO] / float(lengths[classIndex])); | ||||||
|  |         } | ||||||
|  |     } | ||||||
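The backprop rule implemented by the kernel above is simple: since each output is the mean of its segment, every input element i receives gradOut[seg(i)] / lengths[seg(i)]. A CPU reference of that rule (sketch only):

    #include <vector>

    // Mean backprop: each input position gets its segment's gradient divided by the segment length.
    std::vector<double> segmentMeanBPRef(const std::vector<int>& ids,
                                         const std::vector<double>& gradOut,
                                         const std::vector<int>& lengths) {
        std::vector<double> gradIn(ids.size(), 0.0);
        for (size_t i = 0; i < ids.size(); ++i)
            gradIn[i] = gradOut[ids[i]] / lengths[ids[i]];
        return gradIn;
    }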
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentMeanBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, | ||||||
|  |                                                   void* indicesBuf, Nd4jLong* indicesShape, int* lengths, void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad, | ||||||
|  |                                                   Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             yLen = shape::length(indicesShape); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |             currentLen = shape::length(outTad); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { | ||||||
|  | //            auto yIndex = shape::getIndexOffset(i, indicesShape, yLen); | ||||||
|  |             auto segment = y[i]; //yIndex]; | ||||||
|  |             T* currentOut = z + outOffsets[i]; | ||||||
|  |             T* outGrad = gradOut + gradOutOffsets[segment]; | ||||||
|  | 
 | ||||||
|  |             for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { | ||||||
|  |                 auto zIndex = shape::getIndexOffset(e, outTad, currentLen); | ||||||
|  |                 auto gradIndex = shape::getIndexOffset(e, gradOutTad, gradLen); | ||||||
|  |                 if (lengths[segment] > 0) | ||||||
|  |                     currentOut[zIndex] = T(outGrad[gradIndex] / float(lengths[segment])); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // backprop for mean | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     int segmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1; | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  | 
 | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  |         dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); | ||||||
|  |         fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             Nd4jLong loop_size = input->lengthOf(); | ||||||
|  |             auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |             segmentMeanBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), | ||||||
|  |                     input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  | //            auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions); | ||||||
|  |             auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |             segmentMeanBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths, | ||||||
|  |                     output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                     outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         return Status::OK(); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // segment mean bp main | ||||||
|  |     int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input, | ||||||
|  |                 indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template int segmentMeanFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static int unsortedSegmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1; | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  | 
 | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  |         dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); | ||||||
|  |         fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             Nd4jLong loop_size = input->lengthOf(); | ||||||
|  |             auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |             segmentMeanBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), | ||||||
|  |                     input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  | //            auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions); | ||||||
|  |             auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |             segmentMeanBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths, | ||||||
|  |                     output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                     outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         return Status::OK(); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMeanFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | } | ||||||
|  | } | ||||||
							
								
								
									
423 libnd4j/include/ops/declarable/helpers/cuda/segment_min.cu (new file)
| @ -0,0 +1,423 @@ | |||||||
|  | /******************************************************************************* | ||||||
|  |  * Copyright (c) 2015-2018 Skymind, Inc. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0. | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | // | ||||||
|  | //  @author GS <sgazeos@gmail.com> | ||||||
|  | // | ||||||
|  | 
 | ||||||
|  | #include <ops/declarable/helpers/segment.h> | ||||||
|  | #include <ops/declarable/helpers/segment_common.h> | ||||||
|  | #include <NDArrayFactory.h> | ||||||
|  | #include <helpers/ShapeUtils.h> | ||||||
|  | #include <helpers/TAD.h> | ||||||
|  | #include <exceptions/cuda_exception.h> | ||||||
|  | #include <PointersManager.h> | ||||||
|  | #include <ConstantTadHelper.h> | ||||||
|  | 
 | ||||||
|  | namespace nd4j { | ||||||
|  | namespace ops { | ||||||
|  | namespace helpers { | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // Segment ops linear kernels | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template<typename T, typename I> | ||||||
|  |     static __global__ void | ||||||
|  |     segmentMinLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, | ||||||
|  |                            void *output, Nd4jLong *outputShape) { | ||||||
|  |         __shared__ T *val; | ||||||
|  |         __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |         __shared__ T *x; | ||||||
|  |         __shared__ T *z; | ||||||
|  |         __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; | ||||||
|  |             segment = blockIdx.x / threadsPerSegment; | ||||||
|  |             x = reinterpret_cast<T *>(input); | ||||||
|  |             z = reinterpret_cast<T *>(output); | ||||||
|  |             extern __shared__ unsigned char shmem[]; | ||||||
|  |             val = reinterpret_cast<T *>(shmem); | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  |             if (segment < numOfClasses) { | ||||||
|  |                 zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |                 start = starts[segment]; | ||||||
|  |                 finish = start + lengths[segment]; | ||||||
|  |                 z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)]; | ||||||
|  |                 val[segment] = z[zIndex]; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { | ||||||
|  |             auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |             nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template<typename T, typename I> | ||||||
|  |     static __global__ void | ||||||
|  |     unsortedSegmentMinLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape, | ||||||
|  |                                    int *starts, int *lengths, Nd4jLong numOfClasses, void *output, | ||||||
|  |                                    Nd4jLong *outputShape) { | ||||||
|  |         __shared__ T *val; | ||||||
|  |         __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |         __shared__ T *x; | ||||||
|  |         __shared__ T *z; | ||||||
|  |         __shared__ I *y; //int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             segment = blockIdx.x; | ||||||
|  |             x = reinterpret_cast<T *>(input); | ||||||
|  |             z = reinterpret_cast<T *>(output); | ||||||
|  |             y = reinterpret_cast<I *>(indices); | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  |             zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |             if (lengths[segment] > 0) | ||||||
|  |                 z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)]; | ||||||
|  |             else | ||||||
|  |                 z[zIndex] = DataTypeUtils::max<T>(); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  |         if (lengths[segment] > 0) | ||||||
|  |             for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { | ||||||
|  |                 auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |                 auto yIndex = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |                 if (y[yIndex] == segment) { | ||||||
|  |                     nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | // SegmentMin kernel | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentMinTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong len, segment, zIndex, total; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             segment = indices[blockIdx.x]; // / threadsPerSegment; | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment]; | ||||||
|  |             len = shape::length(inputTads); | ||||||
|  |             start = starts[segment]; | ||||||
|  |             finish = start + lengths[segment]; | ||||||
|  |             total = shape::sizeAt(inputShape, 0); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         auto idx = blockIdx.x; | ||||||
|  |         if (blockIdx.x <= total) { | ||||||
|  |             auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx]; | ||||||
|  |             if (blockIdx.x == start) { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     z[zIndex] = x[xIndex]; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             else { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // segment min | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  | 
 | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  | 
 | ||||||
|  |         fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             segmentMinLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             segmentMinTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens}); | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template void segmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static void unsortedSegmentMinFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  | //        NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  | //        NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); | ||||||
|  | //        classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  |         dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); | ||||||
|  | //        int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer()); | ||||||
|  |         fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices}); | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             unsortedSegmentMinLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             output->assign(DataTypeUtils::max<T>()); | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             dims.x = input->sizeAt(0); | ||||||
|  |             segmentMinTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices}); | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output), | ||||||
|  |                               NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentMinBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput, | ||||||
|  |                                                     Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, | ||||||
|  |                                                     void* outputBuf, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradIn; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, gradLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             gradIn = reinterpret_cast<T*>(forwardOutput); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); // shared pointers and lengths must be visible to all threads before the loop below | ||||||
|  | 
 | ||||||
|  |         auto start = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |         auto step = gridDim.x * blockDim.x; | ||||||
|  | 
 | ||||||
|  |         for (auto e = start; e < xLen; e += step) { | ||||||
|  | 
 | ||||||
|  |             auto zOffset = shape::getIndexOffset(e, outputShape, xLen); | ||||||
|  |             auto xOffset = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |             auto yOffset = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |             auto classIndex = y[yOffset]; | ||||||
|  |             auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen); | ||||||
|  |             auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen); | ||||||
|  | 
 | ||||||
|  |             if (nd4j::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) { | ||||||
|  |                 z[zOffset] = gradOut[gradOffsetO]; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
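Unlike the mean case, the min backprop above first needs the forward result (recomputed as tempRes in the functors further down) and then routes a segment's incoming gradient only to the input positions that match that segment's minimum within a 1e-6 tolerance; everything else stays zero. A CPU reference of that routing (sketch only):

    #include <cmath>
    #include <vector>

    // Min backprop: gradient flows only to elements equal (within 1e-6) to their segment's forward minimum.
    std::vector<double> segmentMinBPRef(const std::vector<double>& x,
                                        const std::vector<int>& ids,
                                        const std::vector<double>& forwardMin,
                                        const std::vector<double>& gradOut) {
        std::vector<double> gradIn(x.size(), 0.0);
        for (size_t i = 0; i < x.size(); ++i)
            if (std::fabs(forwardMin[ids[i]] - x[i]) <= 1e-6)
                gradIn[i] = gradOut[ids[i]];
        return gradIn;
    }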
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentMinBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput, | ||||||
|  |                                                  Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, | ||||||
|  |                                                  void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad, | ||||||
|  |                                                  Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets, | ||||||
|  |                                                  Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, | ||||||
|  |                                                  Nd4jLong* outOffsets) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradIn; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             yLen = shape::length(indicesShape); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradIn = reinterpret_cast<T*>(forwardOutput); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |             currentLen = shape::length(outTad); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); // sync shared initialization, as in the other TAD kernels in this file | ||||||
|  | 
 | ||||||
|  |         for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { | ||||||
|  |             auto yIndex = shape::getIndexOffset(i, indicesShape, yLen); | ||||||
|  |             auto segment = y[yIndex]; | ||||||
|  |             T* current = x + inputOffsets[i]; | ||||||
|  |             T* currentOut = z + outOffsets[i]; | ||||||
|  |             T* in = gradIn + gradInOffsets[segment]; | ||||||
|  |             T* outGrad = gradOut + gradOutOffsets[segment]; | ||||||
|  | 
 | ||||||
|  |             for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { | ||||||
|  |                 if (nd4j::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6)) | ||||||
|  |                     currentOut[e] = outGrad[e]; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     int segmentMinFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |         //int numOfClasses = gradOut->sizeAt(0); | ||||||
|  |         // if input is a vector (as in the documentation sample) | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context); | ||||||
|  |         segmentMinFunctor_<T, I>(context, input, indices, &tempRes); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             Nd4jLong loop_size = input->lengthOf(); | ||||||
|  |             auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  | 
 | ||||||
|  |             segmentMinBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions); | ||||||
|  |             auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             Nd4jLong* gradInTads = packGradIn.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets(); | ||||||
|  |             Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |             segmentMinBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), | ||||||
|  |                     inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                     outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); | ||||||
|  |         return Status::OK(); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // segment min bp main | ||||||
|  |     int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input, | ||||||
|  |                 indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template int segmentMinFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static int unsortedSegmentMinFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         //int numOfClasses = gradOut->sizeAt(0); | ||||||
|  |         // if input is a vector (as in the documentation sample) | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context); | ||||||
|  |         unsortedSegmentMinFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes}); | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             Nd4jLong loop_size = input->lengthOf(); | ||||||
|  |             auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |             segmentMinBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions); | ||||||
|  |             auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             Nd4jLong* gradInTads = packGradIn.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets(); | ||||||
|  |             Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |             segmentMinBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), | ||||||
|  |                     inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                     outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes}); | ||||||
|  |         return Status::OK(); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | } | ||||||
|  | } | ||||||
							
								
								
									
419  libnd4j/include/ops/declarable/helpers/cuda/segment_prod.cu  Normal file
| @ -0,0 +1,419 @@ | |||||||
|  | /******************************************************************************* | ||||||
|  |  * Copyright (c) 2015-2018 Skymind, Inc. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0. | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | // | ||||||
|  | //  @author GS <sgazeos@gmail.com> | ||||||
|  | // | ||||||
|  | 
 | ||||||
|  | #include <ops/declarable/helpers/segment.h> | ||||||
|  | #include <ops/declarable/helpers/segment_common.h> | ||||||
|  | #include <NDArrayFactory.h> | ||||||
|  | #include <helpers/ShapeUtils.h> | ||||||
|  | #include <helpers/TAD.h> | ||||||
|  | #include <exceptions/cuda_exception.h> | ||||||
|  | #include <PointersManager.h> | ||||||
|  | #include <ConstantTadHelper.h> | ||||||
|  | 
 | ||||||
|  | namespace nd4j { | ||||||
|  | namespace ops { | ||||||
|  | namespace helpers { | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // Segment Prod ops linear kernels | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
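|  |     // Sorted segment_prod forward (vector input). One block handles one segment: thread 0 seeds a shared | ||||||
|  |     // accumulator with the first element of the segment, all threads multiply the remaining elements in via | ||||||
|  |     // atomicMul, and thread 0 writes the product back to the output slot for that segment. | ||||||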
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentProdLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; | ||||||
|  |             segment = blockIdx.x / threadsPerSegment; | ||||||
|  |             x = reinterpret_cast<T*>(input); | ||||||
|  |             z = reinterpret_cast<T*>(output); | ||||||
|  |             extern __shared__ unsigned char shmem[]; | ||||||
|  |             val = reinterpret_cast<T*>(shmem); | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  |             if (segment < numOfClasses) { | ||||||
|  |                 zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |                 start = starts[segment]; | ||||||
|  |                 finish = start + lengths[segment]; | ||||||
|  |                 //val[segment] = ; | ||||||
|  |                 z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)]; | ||||||
|  |                 val[segment] = z[zIndex]; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | //         auto tid = threadIdx.x + blockIdx.x * blockDim.x; | ||||||
|  | //         auto step = blockDim.x * gridDim.x; | ||||||
|  | 
 | ||||||
|  |         for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { | ||||||
|  |             auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |             nd4j::math::atomics::nd4j_atomicMul(&val[segment], x[xIndex]); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             z[zIndex] = val[segment]; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
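|  |     // Unsorted segment_prod forward (vector input). One block per class: thread 0 seeds the output with the | ||||||
|  |     // first element of the class (or 0 for an empty class), then every thread scans the whole input and | ||||||
|  |     // multiplies in, via atomicMul on global memory, each element whose index maps to this class. | ||||||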
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void unsortedSegmentProdLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ I* y; //int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  | //            threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; | ||||||
|  |             segment = blockIdx.x;// / threadsPerSegment; | ||||||
|  |             x = reinterpret_cast<T*>(input); | ||||||
|  |             z = reinterpret_cast<T*>(output); | ||||||
|  |             y = reinterpret_cast<I*>(indices); | ||||||
|  | //            extern __shared__ unsigned char shmem[]; | ||||||
|  | //            val = reinterpret_cast<T*>(shmem); | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  | //            if (segment < numOfClasses) { | ||||||
|  |             zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |             //start = starts[segment]; | ||||||
|  |             //finish = start + lengths[segment]; | ||||||
|  |             if (lengths[segment] > 0) | ||||||
|  |                 z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)]; | ||||||
|  |             else | ||||||
|  |                 z[zIndex] = 0; //DataTypeUtils::max<T>(); | ||||||
|  | //                val[segment] = z[zIndex]; | ||||||
|  | //            } | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  |         if (lengths[segment] > 0) | ||||||
|  |             for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { | ||||||
|  |                 auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |                 auto yIndex = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |                 if (y[yIndex] == segment && e != starts[segment]) { | ||||||
|  |                     nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // SegmentProd kernel | ||||||
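|  |     // Rank > 1 case: one block per input TAD (outermost sub-array). The first TAD of a segment is copied into | ||||||
|  |     // the output TAD; every other TAD of the same segment is multiplied in element-wise via atomicMul. | ||||||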
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong len, segment, zIndex, total; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             segment = indices[blockIdx.x]; // / threadsPerSegment; | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment]; | ||||||
|  |             len = shape::length(inputTads); | ||||||
|  |             start = starts[segment]; | ||||||
|  |             finish = start + lengths[segment]; | ||||||
|  |             total = shape::sizeAt(inputShape, 0); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         auto idx = blockIdx.x; | ||||||
|  |         if (blockIdx.x <= total) { | ||||||
|  |             auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx]; | ||||||
|  |             if (blockIdx.x == start) { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     z[zIndex] = x[xIndex]; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             else { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
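|  |     // Host-side dispatcher for sorted segment_prod: fillUpSegments builds per-class start offsets and lengths, | ||||||
|  |     // then either the linear kernel (vector input) or the TAD kernel (higher ranks) is launched. | ||||||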
|  |     template <typename T, typename I> | ||||||
|  |     static void segmentProdFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  | 
 | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  | 
 | ||||||
|  |         dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); | ||||||
|  |         fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             segmentProdLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             segmentProdTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template void segmentProdFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
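|  |     // A minimal usage sketch of the public helper (hypothetical data; shapes, device sync and special-buffer | ||||||
|  |     // handling are omitted for brevity): | ||||||
|  |     //   auto x   = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f}); | ||||||
|  |     //   auto idx = NDArrayFactory::create<int>('c', {4}, {0, 0, 1, 1}); | ||||||
|  |     //   auto z   = NDArrayFactory::create<float>('c', {2}); | ||||||
|  |     //   segmentProdFunctor(nd4j::LaunchContext::defaultContext(), &x, &idx, &z);   // expected z: {2.f, 12.f} | ||||||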
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static void unsortedSegmentProdFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  | //        NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  | //        NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); | ||||||
|  | //        classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  |         dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); | ||||||
|  | //        int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer()); | ||||||
|  |         fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             unsortedSegmentProdLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             output->assign(1); | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             dims.x = input->sizeAt(0); | ||||||
|  |             segmentProdTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output), | ||||||
|  |                               FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
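|  |     // Backprop for segment_prod: with y_s = prod_{j in s} x_j, dL/dx_i = dL/dy_{s(i)} * y_{s(i)} / x_i. | ||||||
|  |     // The kernel below applies this rule element-wise as gradOut * forwardOutput / input (undefined for x_i == 0). | ||||||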
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentProdBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput, | ||||||
|  |                                                      Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, | ||||||
|  |                                                      void* outputBuf, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradIn; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, gradLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             gradIn = reinterpret_cast<T*>(forwardOutput); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         auto start = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |         auto step = gridDim.x * blockDim.x; | ||||||
|  | 
 | ||||||
|  |         for (auto e = start; e < xLen; e += step) { | ||||||
|  | 
 | ||||||
|  |             auto zOffset = shape::getIndexOffset(e, outputShape, xLen); | ||||||
|  |             auto xOffset = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |             auto yOffset = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |             auto classIndex = y[yOffset]; | ||||||
|  |             auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen); | ||||||
|  |             auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen); | ||||||
|  | 
 | ||||||
|  |             z[zOffset] = gradOut[gradOffsetO]  * gradIn[gradOffsetI] / x[xOffset]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
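|  |     // TAD variant of the same rule: each block takes one input TAD, looks up its segment through the indices, | ||||||
|  |     // and writes gradOut(segment) * forwardOutput(segment) / input element-wise for that TAD. | ||||||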
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentProdBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput, | ||||||
|  |                                                   Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, | ||||||
|  |                                                   void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad, | ||||||
|  |                                                   Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets, | ||||||
|  |                                                   Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, | ||||||
|  |                                                   Nd4jLong* outOffsets) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradIn; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             yLen = shape::length(indicesShape); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradIn = reinterpret_cast<T*>(forwardOutput); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |             currentLen = shape::length(outTad); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { | ||||||
|  |             auto yIndex = shape::getIndexOffset(i, indicesShape, yLen); | ||||||
|  |             auto segment = y[yIndex]; | ||||||
|  |             T* current = x + inputOffsets[i]; | ||||||
|  |             T* currentOut = z + outOffsets[i]; | ||||||
|  |             T* in = gradIn + gradInOffsets[segment]; | ||||||
|  |             T* outGrad = gradOut + gradOutOffsets[segment]; | ||||||
|  | 
 | ||||||
|  |             for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { | ||||||
|  |                 currentOut[e] = outGrad[e] * in[e] / current[e]; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
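|  |     // Sorted segment_prod backprop entry point: the forward product is recomputed into tempRes, then the | ||||||
|  |     // element-wise rule above is applied with the linear or TAD kernel depending on the input rank. | ||||||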
|  |     template <typename T, typename I> | ||||||
|  |     int segmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context); | ||||||
|  |         segmentProdFunctor_<T, I>(context, input, indices, &tempRes); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             Nd4jLong loopSize = input->lengthOf(); | ||||||
|  |             auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |             segmentProdBPLinearKernel<T,I><<<gradOut->lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions); | ||||||
|  |             auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             Nd4jLong* gradInTads = packGradIn.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets(); | ||||||
|  |             Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |             segmentProdBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), | ||||||
|  |                     inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                     outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         return Status::OK(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input, | ||||||
|  |                 indices, gradOut, output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template int segmentProdFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static int unsortedSegmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context); | ||||||
|  |         unsortedSegmentProdFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             Nd4jLong loopSize = input->lengthOf(); | ||||||
|  |             auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |             segmentProdBPLinearKernel<T,I><<<gradOut->lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions); | ||||||
|  |             auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             Nd4jLong* gradInTads = packGradIn.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets(); | ||||||
|  |             Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |             segmentProdBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), | ||||||
|  |                     inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                     outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         return Status::OK(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentProdFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | } | ||||||
|  | } | ||||||
							
								
								
									
280  libnd4j/include/ops/declarable/helpers/cuda/segment_sqrtn.cu  Normal file
| @ -0,0 +1,280 @@ | |||||||
|  | /******************************************************************************* | ||||||
|  |  * Copyright (c) 2015-2018 Skymind, Inc. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0. | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | // | ||||||
|  | //  @author GS <sgazeos@gmail.com> | ||||||
|  | // | ||||||
|  | 
 | ||||||
|  | #include <ops/declarable/helpers/segment.h> | ||||||
|  | #include <ops/declarable/helpers/segment_common.h> | ||||||
|  | #include <NDArrayFactory.h> | ||||||
|  | #include <helpers/ShapeUtils.h> | ||||||
|  | #include <helpers/TAD.h> | ||||||
|  | #include <exceptions/cuda_exception.h> | ||||||
|  | #include <PointersManager.h> | ||||||
|  | #include <ConstantTadHelper.h> | ||||||
|  | 
 | ||||||
|  | namespace nd4j { | ||||||
|  | namespace ops { | ||||||
|  | namespace helpers { | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
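|  |     // Unsorted segment_sqrt_n forward (vector input): out_s = (1 / sqrt(|s|)) * sum_{j in s} x_j. One block per | ||||||
|  |     // class: thread 0 seeds the output with the first element already scaled by 1/sqrt(length), the remaining | ||||||
|  |     // matching elements are added in via atomicAdd with the same scaling. | ||||||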
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void unsortedSegmentSqrtNLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ I* y; //int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  | //            threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; | ||||||
|  |             segment = blockIdx.x;// / threadsPerSegment; | ||||||
|  |             x = reinterpret_cast<T*>(input); | ||||||
|  |             z = reinterpret_cast<T*>(output); | ||||||
|  |             y = reinterpret_cast<I*>(indices); | ||||||
|  | //            extern __shared__ unsigned char shmem[]; | ||||||
|  | //            val = reinterpret_cast<T*>(shmem); | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  | //            if (segment < numOfClasses) { | ||||||
|  |             zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |             //start = starts[segment]; | ||||||
|  |             //finish = start + lengths[segment]; | ||||||
|  |             if (lengths[segment] > 0) | ||||||
|  |                 z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]); | ||||||
|  |             else | ||||||
|  |                 z[zIndex] = 0; //DataTypeUtils::max<T>(); | ||||||
|  | //                val[segment] = z[zIndex]; | ||||||
|  | //            } | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  |         if (lengths[segment] > 0) | ||||||
|  |             for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) { | ||||||
|  |                 auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |                 auto yIndex = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |                 if (y[yIndex] == segment && e != starts[segment]) { | ||||||
|  |                     nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment])); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // SegmentSqrtN kernel | ||||||
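|  |     // Rank > 1 case: one block per input TAD; each element is scaled by 1/sqrt(segment length) and either | ||||||
|  |     // assigned (first TAD of the segment) or accumulated via atomicAdd into the segment's output TAD. | ||||||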
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentSqrtNTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong len, segment, zIndex, total; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             segment = indices[blockIdx.x]; // / threadsPerSegment; | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment]; | ||||||
|  |             len = shape::length(inputTads); | ||||||
|  |             start = starts[segment]; | ||||||
|  |             finish = start + lengths[segment]; | ||||||
|  |             total = shape::sizeAt(inputShape, 0); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         auto idx = blockIdx.x; | ||||||
|  |         if (blockIdx.x <= total) { | ||||||
|  |             auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx]; | ||||||
|  |             if (blockIdx.x == start) { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     z[zIndex] = x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             else { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment])); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static void unsortedSegmentSqrtNFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  | //        NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  | //        NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); | ||||||
|  | //        classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  |         dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32); | ||||||
|  | //        int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer()); | ||||||
|  |         fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             unsortedSegmentSqrtNLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             output->assign(0); | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             dims.x = input->sizeAt(0); | ||||||
|  |             segmentSqrtNTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output), | ||||||
|  |                               FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSqrtNFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
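|  |     // Backprop for segment_sqrt_n: since out_s = sum_{j in s} x_j / sqrt(|s|), the gradient is simply | ||||||
|  |     // dL/dx_i = dL/dout_{s(i)} / sqrt(|s(i)|), which the BP kernels apply element-wise. | ||||||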
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentSqrtNBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape, | ||||||
|  |                                                       int* lengths, void* outputBuf, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradIn; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, gradLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         auto start = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |         auto step = gridDim.x * blockDim.x; | ||||||
|  | 
 | ||||||
|  |         for (auto e = start; e < xLen; e += step) { | ||||||
|  | 
 | ||||||
|  |             auto zOffset = shape::getIndexOffset(e, outputShape, xLen); | ||||||
|  |             auto xOffset = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |             auto yOffset = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |             auto classIndex = y[yOffset]; | ||||||
|  |             auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen); | ||||||
|  | 
 | ||||||
|  |             z[zOffset] = T(gradOut[gradOffsetO] / math::nd4j_sqrt<int, float>(lengths[classIndex])); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentSqrtNBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, | ||||||
|  |                                                    void* indicesBuf, Nd4jLong* indicesShape, int* lengths, void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad, | ||||||
|  |                                                    Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             yLen = shape::length(indicesShape); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |             currentLen = shape::length(outTad); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { | ||||||
|  | //            auto yIndex = shape::getIndexOffset(i, indicesShape, yLen); | ||||||
|  |             auto segment = y[i]; //yIndex]; | ||||||
|  |             T* currentOut = z + outOffsets[i]; | ||||||
|  |             T* outGrad = gradOut + gradOutOffsets[segment]; | ||||||
|  | 
 | ||||||
|  |             for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { | ||||||
|  |                 auto zIndex = shape::getIndexOffset(e, outTad, currentLen); | ||||||
|  |                 auto gradIndex = shape::getIndexOffset(e, gradOutTad, gradLen); | ||||||
|  |                 if (lengths[segment] > 0) | ||||||
|  |                     currentOut[zIndex] = T(outGrad[gradIndex] / math::nd4j_sqrt<int, float>(lengths[segment])); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static int unsortedSegmentSqrtNFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1; | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  | 
 | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  |         dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); | ||||||
|  |         fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             Nd4jLong loop_size = input->lengthOf(); | ||||||
|  |             auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |             segmentSqrtNBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), | ||||||
|  |                     input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  | //            auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions); | ||||||
|  |             auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |             segmentSqrtNBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths, | ||||||
|  |                     output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                     outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  | 
 | ||||||
|  |         return Status::OK(); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSqrtNFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES); | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | } | ||||||
|  | } | ||||||
							
								
								
									
393  libnd4j/include/ops/declarable/helpers/cuda/segment_sum.cu  Normal file
| @ -0,0 +1,393 @@ | |||||||
|  | /******************************************************************************* | ||||||
|  |  * Copyright (c) 2015-2018 Skymind, Inc. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0. | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | // | ||||||
|  | //  @author GS <sgazeos@gmail.com> | ||||||
|  | // | ||||||
|  | 
 | ||||||
|  | #include <ops/declarable/helpers/segment.h> | ||||||
|  | #include <ops/declarable/helpers/segment_common.h> | ||||||
|  | #include <NDArrayFactory.h> | ||||||
|  | #include <helpers/ShapeUtils.h> | ||||||
|  | #include <helpers/TAD.h> | ||||||
|  | #include <exceptions/cuda_exception.h> | ||||||
|  | #include <PointersManager.h> | ||||||
|  | #include <ConstantTadHelper.h> | ||||||
|  | 
 | ||||||
|  | namespace nd4j { | ||||||
|  | namespace ops { | ||||||
|  | namespace helpers { | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // Segment ops linear kernels | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
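|  |     // segmentSumLinearKernel: one group of blocks per segment; thread 0 seeds the segment's output slot with its first element, the remaining threads accumulate the rest via atomicAdd | ||||||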
|  |     template<typename T, typename I> | ||||||
|  |     static __global__ void | ||||||
|  |     segmentSumLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses, | ||||||
|  |                            void *output, Nd4jLong *outputShape) { | ||||||
|  |         __shared__ T *val; | ||||||
|  |         __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |         __shared__ T *x; | ||||||
|  |         __shared__ T *z; | ||||||
|  |         __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses; | ||||||
|  |             segment = blockIdx.x / threadsPerSegment; | ||||||
|  |             x = reinterpret_cast<T *>(input); | ||||||
|  |             z = reinterpret_cast<T *>(output); | ||||||
|  | 
 | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |             if (segment < numOfClasses) { | ||||||
|  |                 zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |                 start = starts[segment]; | ||||||
|  |                 finish = start + lengths[segment]; | ||||||
|  |                 //val[segment] = ; | ||||||
|  |                 z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)]; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) { | ||||||
|  |             auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |             nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template<typename T, typename I> | ||||||
|  |     static __global__ void | ||||||
|  |     unsortedSegmentSumLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape, | ||||||
|  |                                    int *starts, int *lengths, Nd4jLong numOfClasses, void *output, | ||||||
|  |                                    Nd4jLong *outputShape) { | ||||||
|  |         __shared__ T *val; | ||||||
|  |         __shared__ Nd4jLong xLen, zLen, segment, zIndex; | ||||||
|  |         __shared__ T *x; | ||||||
|  |         __shared__ T *z; | ||||||
|  |         __shared__ I *y; //int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             segment = blockIdx.x; | ||||||
|  |             x = reinterpret_cast<T *>(input); | ||||||
|  |             z = reinterpret_cast<T *>(output); | ||||||
|  |             y = reinterpret_cast<I *>(indices); | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             zLen = shape::length(outputShape); | ||||||
|  | 
 | ||||||
|  |             zIndex = shape::getIndexOffset(segment, outputShape, zLen); | ||||||
|  |             if (lengths[segment] > 0) | ||||||
|  |                 z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)]; | ||||||
|  |             else | ||||||
|  |                 z[zIndex] = 0; //DataTypeUtils::max<T>(); | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         if (lengths[segment] > 0) | ||||||
|  |             for (auto e = threadIdx.x; e < xLen; e += blockDim.x) { | ||||||
|  |                 auto xIndex = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |                 auto yIndex = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |                 if (y[yIndex] == segment && e != starts[segment]) { | ||||||
|  |                     nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // SegmentSum kernel | ||||||
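|  |     // segmentSumTadKernel: one block per input TAD (row along dimension 0); the block owning the first TAD of a segment copies it into the output TAD, the other blocks of that segment accumulate element-wise with atomicAdd | ||||||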
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentSumTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) { | ||||||
|  |         __shared__ T* val; | ||||||
|  |         __shared__ Nd4jLong len, segment, zIndex, total; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ int threadsPerSegment, start, finish; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             segment = indices[blockIdx.x]; // / threadsPerSegment; | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment]; | ||||||
|  |             len = shape::length(inputTads); | ||||||
|  |             start = starts[segment]; | ||||||
|  |             finish = start + lengths[segment]; | ||||||
|  |             total = shape::sizeAt(inputShape, 0); | ||||||
|  | 
 | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  | 
 | ||||||
|  |         auto idx = blockIdx.x; | ||||||
|  |         if (blockIdx.x < total) { | ||||||
|  |             auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx]; | ||||||
|  |             if (blockIdx.x == start) { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     z[zIndex] = x[xIndex]; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             else { | ||||||
|  |                 for (auto e = threadIdx.x; e < len; e += blockDim.x) { | ||||||
|  |                     auto xIndex = shape::getIndexOffset(e, inputTads, len); | ||||||
|  |                     auto zIndex = shape::getIndexOffset(e, outputTads, len); | ||||||
|  |                     if (lengths[segment]) | ||||||
|  |                         nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
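|  |     // host dispatcher: derives the number of classes from the last (sorted) index, builds the per-class begin/length arrays via fillUpSegments, then launches the linear kernel for vector inputs or the TAD kernel otherwise | ||||||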
|  |     template <typename T, typename I> | ||||||
|  |     static void segmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1; | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses}); | ||||||
|  | 
 | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  | 
 | ||||||
|  |         dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32); | ||||||
|  |         fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             segmentSumLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             segmentSumTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template void segmentSumFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
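|  |     // unsorted variant: numOfClasses is supplied explicitly (indices need not be sorted); on the TAD path the output is zeroed first, since empty classes receive no writes from the kernel | ||||||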
|  |     template <typename T, typename I> | ||||||
|  |     static void unsortedSegmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  | //        NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2}); | ||||||
|  |         NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  |         NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses}); | ||||||
|  | //        NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0}); | ||||||
|  | //        classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes); | ||||||
|  |         classesRangesBegs.assign(indices->lengthOf()); | ||||||
|  |         classesRangesLens.assign(0); | ||||||
|  |         dim3 dims(numOfClasses, indices->lengthOf(), (numOfClasses + 1) * 64); | ||||||
|  | //        int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer()); | ||||||
|  |         fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens); | ||||||
|  |         int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer()); | ||||||
|  |         int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer()); | ||||||
|  | 
 | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             unsortedSegmentSumLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             output->assign(0); | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             dims.x = input->sizeAt(0); | ||||||
|  |             segmentSumTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output), | ||||||
|  |                               NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSumFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  | 
 | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // Backpropagate ops | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     // Sorted sum backpropagate | ||||||
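|  |     // for a sum, the gradient w.r.t. every input element is simply the gradient of its segment, so both BP kernels copy gradOut at the element's class index into the matching output positions | ||||||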
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentSumBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, | ||||||
|  |                                                     void* indicesBuf, Nd4jLong* indicesShape, void* outputBuf, Nd4jLong* outputShape) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradIn; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, gradLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         auto start = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |         auto step = gridDim.x * blockDim.x; | ||||||
|  | 
 | ||||||
|  |         for (auto e = start; e < xLen; e += step) { | ||||||
|  | 
 | ||||||
|  |             auto zOffset = shape::getIndexOffset(e, outputShape, xLen); | ||||||
|  |             auto xOffset = shape::getIndexOffset(e, inputShape, xLen); | ||||||
|  |             auto yOffset = shape::getIndexOffset(e, indicesShape, xLen); | ||||||
|  |             auto classIndex = y[yOffset]; | ||||||
|  |             auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen); | ||||||
|  | 
 | ||||||
|  |             z[zOffset] = gradOut[gradOffsetO]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static __global__ void segmentSumBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, | ||||||
|  |                                                  void* indicesBuf, Nd4jLong* indicesShape, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* inputTad, | ||||||
|  |                                                  Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) { | ||||||
|  |         __shared__ T* x; | ||||||
|  |         __shared__ T* gradOut; | ||||||
|  |         __shared__ I* y; | ||||||
|  |         __shared__ T* z; | ||||||
|  |         __shared__ Nd4jLong xLen, yLen, gradLen, currentLen; | ||||||
|  | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             xLen = shape::length(inputShape); | ||||||
|  |             x = reinterpret_cast<T*>(inputBuf); | ||||||
|  |             y = reinterpret_cast<I*>(indicesBuf); | ||||||
|  |             z = reinterpret_cast<T*>(outputBuf); | ||||||
|  |             yLen = shape::length(indicesShape); | ||||||
|  |             gradOut = reinterpret_cast<T*>(eps); | ||||||
|  |             gradLen = shape::length(epsShape); | ||||||
|  |             currentLen = shape::length(outTad); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         for (auto i = blockIdx.x; i < yLen; i += gridDim.x) { | ||||||
|  |             auto yIndex = shape::getIndexOffset(i, indicesShape, yLen); | ||||||
|  |             auto segment = y[yIndex]; | ||||||
|  |             T* currentOut = z + outOffsets[i]; | ||||||
|  |             T* outGrad = gradOut + gradOutOffsets[segment]; | ||||||
|  | 
 | ||||||
|  |             for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) { | ||||||
|  |                 currentOut[e] = outGrad[e]; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     int segmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             Nd4jLong loop_size = input->lengthOf(); | ||||||
|  |             auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |             segmentSumBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), | ||||||
|  |                     input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |             segmentSumBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), | ||||||
|  |                     inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                     outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         return Status::OK(); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input, | ||||||
|  |                 indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template int segmentSumFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  | 
 | ||||||
|  |     template <typename T, typename I> | ||||||
|  |     static int unsortedSegmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         if (input->isVector()) { | ||||||
|  |             Nd4jLong loop_size = input->lengthOf(); | ||||||
|  |             auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1); | ||||||
|  |             segmentSumBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), | ||||||
|  |                     input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); | ||||||
|  |             auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); | ||||||
|  |             auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); | ||||||
|  |             auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions); | ||||||
|  |             Nd4jLong* inputTads = packX.specialShapeInfo(); | ||||||
|  |             Nd4jLong* inputTadOffsets = packX.specialOffsets(); | ||||||
|  |             Nd4jLong* outputTads = packZ.specialShapeInfo(); | ||||||
|  |             Nd4jLong* outputTadOffsets = packZ.specialOffsets(); | ||||||
|  |             Nd4jLong* gradOutTads = packGradOut.specialShapeInfo(); | ||||||
|  |             Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets(); | ||||||
|  | 
 | ||||||
|  |             segmentSumBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), | ||||||
|  |                     gradOut->specialBuffer(), gradOut->specialShapeInfo(), | ||||||
|  |                     indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), | ||||||
|  |                     inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets, | ||||||
|  |                     outputTads, outputTadOffsets); | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({output}, {input, indices, gradOut}); | ||||||
|  |         return Status::OK(); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { | ||||||
|  |         BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  |     } | ||||||
|  |     // -------------------------------------------------------------------------------------------------------------- // | ||||||
|  |     BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSumFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES); | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | } | ||||||
|  | } | ||||||
| @ -24,16 +24,40 @@ namespace nd4j { | |||||||
| namespace ops { | namespace ops { | ||||||
| namespace helpers { | namespace helpers { | ||||||
| 
 | 
 | ||||||
|     template <typename T> |     template <typename I, typename B> | ||||||
|     static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) { |     static __global__ void sequenceMaskKernel(void* inputBuf, Nd4jLong* inputShape, void* outputBuf, Nd4jLong* outputShape, int maxIndex) { | ||||||
|         // | 
 | ||||||
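|  |         // blocks stride over mask positions i, threads stride over sequence entries k: output[k * maxIndex + i] is set to true whenever i < input[k] | ||||||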
|  |         __shared__ I* input; | ||||||
|  |         __shared__ B* output; | ||||||
|  |         __shared__ Nd4jLong inputLen, outputLen; | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             input = reinterpret_cast<I*>(inputBuf); | ||||||
|  |             output = reinterpret_cast<B*>(outputBuf); | ||||||
|  |             inputLen = shape::length(inputShape); | ||||||
|  |             outputLen = shape::length(outputShape); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         for (auto i = blockIdx.x; i < maxIndex; i += gridDim.x) | ||||||
|  |             for(auto k = threadIdx.x; k < inputLen; k += blockDim.x) | ||||||
|  |                 if (i < input[shape::getIndexOffset(k, inputShape, inputLen)]) | ||||||
|  |                     output[shape::getIndexOffset(k * maxIndex + i, outputShape, outputLen)] = B(true); | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     template <typename I, typename B> | ||||||
|  |     static void sequenceMask_(LaunchContext* context, NDArray* input, NDArray* output, int maxIndex) { | ||||||
|  |         dim3 launchDims(maxIndex, input->lengthOf(), 128); | ||||||
|  |         NDArray::prepareSpecialUse({output}, {input}); | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         sequenceMaskKernel<I, B><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), maxIndex); | ||||||
|  |         NDArray::registerSpecialUse({output}, {input}); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { |     void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { | ||||||
|         BUILD_SINGLE_SELECTOR(input->dataType(), sequenceMask_, (input, output, maxIndex), LIBND4J_TYPES); |         BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     BUILD_SINGLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), LIBND4J_TYPES); |     BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES); | ||||||
| } | } | ||||||
| } | } | ||||||
| } | } | ||||||
| @ -456,27 +456,246 @@ void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArr | |||||||
|     manager.synchronize(); |     manager.synchronize(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /////////////////////////////////////////////////////////////////// | ||||||
|  | template<typename T> | ||||||
|  | __global__ static void scatterUpdateCuda(const int opCode, const int numOfInd, | ||||||
|  |                                               void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, | ||||||
|  |                                               void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, | ||||||
|  |                                               const int* indexes) { | ||||||
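|  |     // ownership rule: the row selected by indexes[e] is handled by exactly one block (xIndex, or xIndex % gridDim.x when it exceeds the grid), so duplicate indices cannot race; that block's threads stride over the row and apply the operation selected by opCode | ||||||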
| 
 | 
 | ||||||
|  |     __shared__ T *x, *y; | ||||||
|  |     __shared__ Nd4jLong arrLenX, arrLenY; | ||||||
| 
 | 
 | ||||||
|  |     for (int e = 0; e < numOfInd; e++ ) { | ||||||
| 
 | 
 | ||||||
|  |         const auto xIndex = indexes[e]; | ||||||
|  |         const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x; | ||||||
| 
 | 
 | ||||||
|  |         if (!isOwner) | ||||||
|  |             continue; | ||||||
| 
 | 
 | ||||||
|  |         if (threadIdx.x == 0) { | ||||||
|  |             x = reinterpret_cast<T*>(vx) + xOffsets[xIndex]; | ||||||
|  |             y = reinterpret_cast<T*>(vy) + yOffsets[e]; | ||||||
|  |             arrLenX = shape::length(xShapeInfo); | ||||||
|  |             arrLenY = shape::length(yShapeInfo); | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|  |         __syncthreads(); | ||||||
| 
 | 
 | ||||||
|  |         if (arrLenX != arrLenY) | ||||||
|  |             return; | ||||||
| 
 | 
 | ||||||
|  |         for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) { | ||||||
|  | 
 | ||||||
|  |             const auto xOffset = shape::getIndexOffset(i, xShapeInfo, arrLenX); | ||||||
|  |             const auto yOffset = shape::getIndexOffset(i, yShapeInfo, arrLenY); | ||||||
|  | 
 | ||||||
|  |             switch (opCode) { | ||||||
|  |                 case 0: | ||||||
|  |                     x[xOffset] += y[yOffset]; | ||||||
|  |                     break; | ||||||
|  |                 case 1: | ||||||
|  |                     x[xOffset] -= y[yOffset]; | ||||||
|  |                     break; | ||||||
|  |                 case 2: | ||||||
|  |                     x[xOffset] *= y[yOffset]; | ||||||
|  |                     break; | ||||||
|  |                 case 3: | ||||||
|  |                     x[xOffset] /= y[yOffset]; | ||||||
|  |                     break; | ||||||
|  |                 case 4: | ||||||
|  |                     x[xOffset] = y[yOffset] - x[xOffset]; | ||||||
|  |                     break; | ||||||
|  |                 case 5: | ||||||
|  |                     x[xOffset] = y[yOffset] / x[xOffset]; | ||||||
|  |                     break; | ||||||
|  |                 case 6: | ||||||
|  |                     x[xOffset] = y[yOffset]; | ||||||
|  |                     break; | ||||||
|  |                 default: | ||||||
|  |                     continue; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         __syncthreads(); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | template<typename T> | ||||||
|  | __host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const int* indexes) { | ||||||
|  | 
 | ||||||
|  |     scatterUpdateCuda<T><<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfInd, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes); | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
|  | void scatterUpdate(nd4j::LaunchContext* context, NDArray& input, NDArray& updates, const std::vector<int>* intArgs) { | ||||||
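|  |     // intArgs layout: { opCode, numOfDims, dim_0 .. dim_{numOfDims-1}, numOfInd, ind_0 .. ind_{numOfInd-1} } | ||||||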
|  | 
 | ||||||
|  |     const int opCode    = (*intArgs)[0]; | ||||||
|  |     const int numOfDims = (*intArgs)[1]; | ||||||
|  |     const int numOfInd  = (*intArgs)[2 + numOfDims]; | ||||||
|  | 
 | ||||||
|  |     std::vector<int> tadDimensions(numOfDims); | ||||||
|  |     for (int e = 2; e < 2 + numOfDims; e++) | ||||||
|  |         tadDimensions[e-2] = (*intArgs)[e]; | ||||||
|  | 
 | ||||||
|  |     auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), tadDimensions); | ||||||
|  |     auto packY = ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), tadDimensions); | ||||||
|  | 
 | ||||||
|  |     NDArray indices(const_cast<int*>(intArgs->data()) + numOfDims + 3, 'c', {numOfInd}, nd4j::DataType::INT32, context); | ||||||
|  | 
 | ||||||
|  |     PointersManager manager(context, "scatterUpdate"); | ||||||
|  | 
 | ||||||
|  |     NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices}); | ||||||
|  |     BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, (context->getCudaStream(), opCode, numOfInd, input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), updates.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), reinterpret_cast<int*>(indices.getSpecialBuffer())), LIBND4J_TYPES); | ||||||
|  |     NDArray::registerSpecialUse({&input}, {&input, &updates, &indices}); | ||||||
|  | 
 | ||||||
|  |     manager.synchronize(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|     template <typename T> |     template <typename T> | ||||||
|     void randomShuffle_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) { |     static __global__ void swapShuffleKernel(T* input, Nd4jLong* shape, Nd4jLong firstDim, Nd4jLong len, nd4j::graph::RandomGenerator* rng) { | ||||||
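|  |     // in-place device Fisher-Yates sweep: threads stride down from firstDim - 1, each swapping element i with a randomly chosen element r in [0, i) | ||||||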
|  |         auto tid = blockIdx.x * blockDim.x; | ||||||
|  |         auto step = blockDim.x * gridDim.x; | ||||||
|  | 
 | ||||||
|  |         for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) { | ||||||
|  |             int r = rng->relativeInt(i) % i; | ||||||
|  |             if (i != r) { | ||||||
|  |                 T e0 = input[shape::getIndexOffset(i, shape, len)]; | ||||||
|  |                 T e1 = input[shape::getIndexOffset(r, shape, len)]; | ||||||
|  |                 //math::nd4j_swap<T>(input(i), input(r)); | ||||||
|  |                 input[shape::getIndexOffset(i, shape, len)] = e1; | ||||||
|  |                 input[shape::getIndexOffset(r, shape, len)] = e0; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     template <typename T> | ||||||
|  |     static __global__ void fillShuffleKernel(T* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape, Nd4jLong firstDim, Nd4jLong len, int* indices, nd4j::graph::RandomGenerator* rng) { | ||||||
|  | 
 | ||||||
|  | //        PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) | ||||||
|  |         auto tid = blockIdx.x * blockDim.x; | ||||||
|  |         auto step = blockDim.x * gridDim.x; | ||||||
|  | 
 | ||||||
|  |         for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) { | ||||||
|  |             int r = rng->relativeInt(i) % i; | ||||||
|  |             output[shape::getIndexOffset(i, outputShape, len)] = input[shape::getIndexOffset(indices[r], inputShape, len)]; | ||||||
|  |             if(i != r) { | ||||||
|  |                 output[shape::getIndexOffset(r, outputShape, len)] = input[shape::getIndexOffset(indices[i], inputShape, len)]; | ||||||
|  | //                output.p(r, input.e<T>(indices[i])); | ||||||
|  | //                math::nd4j_swap<int>(indices[i], indices[r]); | ||||||
|  |                 atomicExch(&indices[i], indices[r]); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  |     ////////////////////////////////////////////////////////////////////////// | ||||||
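|  |     // three paths: nothing to shuffle -> plain copy; vector input -> device-side Fisher-Yates kernels with the generator copied to device memory; rank > 1 -> host-side Fisher-Yates over the sub-arrays along dimension 0 | ||||||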
|  |     template <typename T> | ||||||
|  |     void randomShuffle_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) { | ||||||
|  | 
 | ||||||
|  |         // check edge cases first | ||||||
|  |         int temp; | ||||||
|  |         const int firstDim = input.sizeAt(0); | ||||||
|  |         auto stream = context->getCudaStream(); | ||||||
|  |         NDArray::prepareSpecialUse({&output}, {&input}); | ||||||
|  |         if(input.lengthOf() == 1 || firstDim == 1) { | ||||||
|  |             if(!isInplace) | ||||||
|  |                 output.assign(input); | ||||||
|  |         } | ||||||
|  |         else if (input.isVector() || shape::isLikeVector(input.getShapeInfo(), temp)) { | ||||||
|  | 
 | ||||||
|  |             // apply Fisher-Yates shuffle | ||||||
|  |             nd4j::graph::RandomGenerator* dRandom = nullptr; | ||||||
|  |             cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator)); | ||||||
|  |             cudaMemcpy(dRandom, &rng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice); | ||||||
|  |             T* inputBuf = reinterpret_cast<T*>(input.specialBuffer()); | ||||||
|  |             if(isInplace) { | ||||||
|  |                 swapShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, input.lengthOf(), dRandom); | ||||||
|  |             } | ||||||
|  |             else { | ||||||
|  |                 std::vector<int> indices(firstDim); | ||||||
|  |                 std::iota(indices.begin(), indices.end(), 0); | ||||||
|  |                 cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice); | ||||||
|  |                 //output.p<T>(Nd4jLong(0), input.e<T>(0)); | ||||||
|  |                 PointersManager pointersManager(context, "helper::randomShuffle_"); | ||||||
|  |                 int* indicesDev = reinterpret_cast<int*>(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int))); | ||||||
|  |                 T* outputBuf = reinterpret_cast<T*>(output.specialBuffer()); | ||||||
|  |                 fillShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, input.lengthOf(), indicesDev, dRandom); | ||||||
|  |                 pointersManager.synchronize(); | ||||||
|  |             } | ||||||
|  | //            rng.rewindH(firstDim - 1); | ||||||
|  |             cudaFree(dRandom); | ||||||
|  |         } | ||||||
|  |         else { | ||||||
|  | 
 | ||||||
|  |             // evaluate sub-arrays list of input array through all dimensions excluding first one | ||||||
|  |             std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0}); | ||||||
|  |             auto subArrsListIn = input.allTensorsAlongDimension(dimensions); | ||||||
|  | 
 | ||||||
|  |             // apply Fisher-Yates shuffle | ||||||
|  |             if(isInplace) { | ||||||
|  |                 PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold()) | ||||||
|  |                 for(int i = firstDim - 1; i > 0; --i) { | ||||||
|  |                     int r = rng.relativeInt(i) % i; | ||||||
|  | 
 | ||||||
|  |                     if(i != r) | ||||||
|  |                         subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r)); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |             else { | ||||||
|  |                 // evaluate sub-arrays list of output array through all dimensions excluding first one | ||||||
|  |                 auto subArrsListOut = output.allTensorsAlongDimension(dimensions); | ||||||
|  |                 std::vector<int> indices(firstDim); | ||||||
|  |                 std::iota(indices.begin(), indices.end(), 0); | ||||||
|  |                 bool isZeroShuffled = false; | ||||||
|  |                 PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) | ||||||
|  |                 for(int i = firstDim - 1; i > 0; --i) { | ||||||
|  |                     int r = rng.relativeInt(i) % i; | ||||||
|  |                     subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r])); | ||||||
|  |                     if(r == 0) | ||||||
|  |                         isZeroShuffled = true; | ||||||
|  | 
 | ||||||
|  |                     if(i != r) { | ||||||
|  |                         subArrsListOut->at(r)->assign(subArrsListIn->at(indices[i])); | ||||||
|  |                         math::nd4j_swap<int>(indices[i], indices[r]); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |                 if(!isZeroShuffled) | ||||||
|  |                     subArrsListOut->at(0)->assign(subArrsListIn->at(0)); | ||||||
|  |                 delete subArrsListOut; | ||||||
|  |             } | ||||||
|  |             rng.rewindH(firstDim-1); | ||||||
|  |             delete subArrsListIn; | ||||||
|  |         } | ||||||
|  |         NDArray::registerSpecialUse({&output}, {&input}); | ||||||
| 
 | 
 | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) { |     void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) { | ||||||
|         BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES); |         BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace), LIBND4J_TYPES); |     BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES); | ||||||
| 
 | 
 | ||||||
|     //////////////////////////////////////////////////////////////////////// |     //////////////////////////////////////////////////////////////////////// | ||||||
|     template<typename T> |     template<typename T> | ||||||
| @ -496,11 +715,6 @@ void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArr | |||||||
| void eye(nd4j::LaunchContext * context, NDArray& output) { | void eye(nd4j::LaunchContext * context, NDArray& output) { | ||||||
| 
 | 
 | ||||||
|     output.setIdentity(); |     output.setIdentity(); | ||||||
| } |  | ||||||
| 
 |  | ||||||
|     ////////////////////////////////////////////////////////////////////////// |  | ||||||
|     void scatterUpdate(nd4j::LaunchContext * context, NDArray& operand, NDArray& updates, const std::vector<int>* intArgs) { |  | ||||||
| 
 |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|     ////////////////////////////////////////////////////////////////////////// |     ////////////////////////////////////////////////////////////////////////// | ||||||
|  | |||||||
| @ -33,27 +33,7 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
| 	void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h); | 	void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h); | ||||||
| 
 | 
 | ||||||
| 	void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0, | 	void gruCellBP(nd4j::LaunchContext* context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, NDArray* dLdb, NDArray* dLdbc); | ||||||
|                   const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| //////////////////////////////////////////////////////////////////////////
 |  | ||||||
| FORCEINLINE NDArray sigmoid(const NDArray& arr) { |  | ||||||
|     return (const_cast<NDArray&>(arr)).transform(transform::Sigmoid); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| FORCEINLINE void sigmoidInplace(const NDArray& arr) { |  | ||||||
|     (const_cast<NDArray&>(arr)).applyTransform(transform::Sigmoid); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| //////////////////////////////////////////////////////////////////////////
 |  | ||||||
| FORCEINLINE NDArray tanh(const NDArray& arr) { |  | ||||||
|     return (const_cast<NDArray&>(arr)).transform(transform::Tanh); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| FORCEINLINE void tanhInplace(const NDArray& arr) { |  | ||||||
|     (const_cast<NDArray&>(arr)).applyTransform(transform::Tanh); |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
| } | } | ||||||
|  | |||||||
| @ -27,37 +27,37 @@ namespace nd4j { | |||||||
|     namespace ops { |     namespace ops { | ||||||
|         namespace helpers { |         namespace helpers { | ||||||
|             template <typename T> |             template <typename T> | ||||||
|             FORCEINLINE Nd4jLong longBytes(T value); |             FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value); | ||||||
| 
 | 
 | ||||||
|             template <> |             template <> | ||||||
|             FORCEINLINE Nd4jLong longBytes(float value) { |             FORCEINLINE _CUDA_HD Nd4jLong longBytes(float value) { | ||||||
|                 int intie = *(int *)&value; |                 int intie = *(int *)&value; | ||||||
|                 return static_cast<Nd4jLong>(intie); |                 return static_cast<Nd4jLong>(intie); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             template <> |             template <> | ||||||
|             FORCEINLINE Nd4jLong longBytes(double value) { |             FORCEINLINE _CUDA_HD Nd4jLong longBytes(double value) { | ||||||
|                 Nd4jLong longie = *(Nd4jLong *)&value; |                 Nd4jLong longie = *(Nd4jLong *)&value; | ||||||
|                 return longie; |                 return longie; | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             template <> |             template <> | ||||||
|             FORCEINLINE Nd4jLong longBytes(float16 value) { |             FORCEINLINE _CUDA_HD Nd4jLong longBytes(float16 value) { | ||||||
|                 return longBytes<float>((float) value); |                 return longBytes<float>((float) value); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             template <> |             template <> | ||||||
|             FORCEINLINE Nd4jLong longBytes(Nd4jLong value) { |             FORCEINLINE _CUDA_HD Nd4jLong longBytes(Nd4jLong value) { | ||||||
|                 return value; |                 return value; | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             template <> |             template <> | ||||||
|             FORCEINLINE Nd4jLong longBytes(bfloat16 value) { |             FORCEINLINE _CUDA_HD Nd4jLong longBytes(bfloat16 value) { | ||||||
|                 return longBytes<float>((float) value); |                 return longBytes<float>((float) value); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             template <typename T> |             template <typename T> | ||||||
|             FORCEINLINE Nd4jLong longBytes(T value) { |             FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value) { | ||||||
|                 return longBytes<Nd4jLong>((Nd4jLong) value); |                 return longBytes<Nd4jLong>((Nd4jLong) value); | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -30,37 +30,30 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| static FORCEINLINE NDArray activation(const NDArray& arr) { | void rnnCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* hPrev, NDArray* ht) { | ||||||
| 
 | 
 | ||||||
|     return (const_cast<NDArray&>(arr)).transform(transform::Tanh); |     // xt    input [bS x iS]
 | ||||||
|  |     // Wx    input-to-hidden weights, [iS  x nU]
 | ||||||
|  |     // Wh    hidden-to-hidden weights, [nU x nU]
 | ||||||
|  |     // b     biases, [2*nU]: {0, nU} are input-to-hidden biases and {nU, 2*nU} are hidden-to-hidden biases
 | ||||||
|  |     // hPrev previous cell output [bS x nU],  that is at previous time step t-1, in case of projection=false -> nU=nU!!!
 | ||||||
|  | 
 | ||||||
|  |     const int nU = hPrev->sizeAt(1); | ||||||
|  | 
 | ||||||
|  |     // ht is current cell output [bS x nU], that is at current time step t
 | ||||||
|  |     ht->assign(mmul(*xt, *Wx) + (*b)({{0, nU}})  +  mmul(*hPrev, *Wh) + (*b)({{nU, 2*nU}}));     // [bS x nU] + [nU]  +  [bS x nU] + [nU] = [bS x nU]
 | ||||||
|  |     ht->applyTransform(transform::Tanh); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| //////////////////////////////////////////////////////////////////////////
 |  | ||||||
| void rnnCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* ht_1, NDArray* ht) { |  | ||||||
| 
 |  | ||||||
|     // xt   input [bS x inSize]
 |  | ||||||
|     // Wx   input-to-hidden weights, [inSize  x numUnits]
 |  | ||||||
|     // Wh   hidden-to-hidden weights, [numUnits x numUnits]
 |  | ||||||
|     // b    biases, [2*numUnits]: {0, numUnits} are input-to-hidden biases and {numUnits, 2*numUnits} are hidden-to-hidden biases
 |  | ||||||
|     // ht_1 previous cell output [bS x numUnits],  that is at previous time step t-1, in case of projection=false -> numUnits=numUnits!!!
 |  | ||||||
| 
 |  | ||||||
|     const int numUnits  = ht_1->sizeAt(1); |  | ||||||
|      |  | ||||||
|     // ht is current cell output [bS x numUnits], that is at current time step t                
 |  | ||||||
|     ht->assign(activation(mmul(*xt, *Wx) + (*b)({{0, numUnits}})  +  mmul(*ht_1, *Wh) + (*b)({{numUnits, 2*numUnits}})));     // [bS x numUnits] + [numUnits]  +  [bS x numUnits] + [numUnits] = [bS x numUnits]
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| //////////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////////
 | ||||||
| void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal) { | void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal) { | ||||||
| 
 | 
 | ||||||
|     // x   input [time x bS x inSize]
 |     // x   input [time x bS x iS]
 | ||||||
| 	// Wx  input-to-hidden  weights, [inSize  x numUnits]
 | 	// Wx  input-to-hidden  weights, [iS  x nU]
 | ||||||
|     // Wh  hidden-to-hidden weights, [numUnits x numUnits]
 |     // Wh  hidden-to-hidden weights, [nU x nU]
 | ||||||
| 	// b   biases for, [2*numUnits]
 | 	// b   biases for, [2*nU]
 | ||||||
| 
 | 
 | ||||||
| 	// h0          initial cell output (at time step = 0) [bS x numUnits]
 | 	// h0          initial cell output (at time step = 0) [bS x nU]
 | ||||||
| 	// maxTimeStep vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
 | 	// maxTimeStep vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
 | ||||||
| 
 | 
 | ||||||
|     const int time = x->sizeAt(0); |     const int time = x->sizeAt(0); | ||||||
| @ -82,16 +75,16 @@ void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* | |||||||
| 
 | 
 | ||||||
|             auto xt   = (*x)({t,t+1, e,e+1, 0,0}, true); |             auto xt   = (*x)({t,t+1, e,e+1, 0,0}, true); | ||||||
|             auto ht   = (*h)({t,t+1, e,e+1, 0,0}, true); |             auto ht   = (*h)({t,t+1, e,e+1, 0,0}, true); | ||||||
|             auto ht_1 = (*hFinal)({e,e+1, 0,0}, true);                       // previous state
 |             auto hPrev = (*hFinal)({e,e+1, 0,0}, true);                       // previous state
 | ||||||
| 
 | 
 | ||||||
|             if(t >= maxStep) { |             if(t >= maxStep) { | ||||||
|                 ht = 0.; |                 ht = 0.; | ||||||
|                 if(maxStep != 0) |                 if(maxStep != 0) | ||||||
|                     ht_1.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0})); |                     hPrev.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0})); | ||||||
|             } |             } | ||||||
|             else { |             else { | ||||||
|                 helpers::rnnCell(context, &xt, Wx, Wh, b, &ht_1, &ht); |                 helpers::rnnCell(context, &xt, Wx, Wh, b, &hPrev, &ht); | ||||||
|                 ht_1.assign(ht); |                 hPrev.assign(ht); | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | |||||||
							
								
								
									
36  libnd4j/include/ops/declarable/helpers/segment_common.h  Normal file
							| @ -0,0 +1,36 @@ | |||||||
|  | /*******************************************************************************
 | ||||||
|  |  * Copyright (c) 2015-2018 Skymind, Inc. | ||||||
|  |  * | ||||||
|  |  * This program and the accompanying materials are made available under the | ||||||
|  |  * terms of the Apache License, Version 2.0 which is available at | ||||||
|  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||||
|  |  * | ||||||
|  |  * Unless required by applicable law or agreed to in writing, software | ||||||
|  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||||
|  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||||
|  |  * License for the specific language governing permissions and limitations | ||||||
|  |  * under the License. | ||||||
|  |  * | ||||||
|  |  * SPDX-License-Identifier: Apache-2.0 | ||||||
|  |  ******************************************************************************/ | ||||||
|  | 
 | ||||||
|  | //
 | ||||||
|  | //  @author sgazeos@gmail.com
 | ||||||
 |  | //  @brief common helper functions for segment_* ops (segment_max, segment_min, etc.)
 | ||||||
 |  | //  @brief common helper functions for unsorted_segment_* ops (unsorted_segment_max, etc.)
 | ||||||
|  | //
 | ||||||
|  | #ifndef __SEGMENT_COMMON_HELPERS__ | ||||||
|  | #define __SEGMENT_COMMON_HELPERS__ | ||||||
|  | #include <op_boilerplate.h> | ||||||
|  | #include <NDArray.h> | ||||||
|  | 
 | ||||||
|  | namespace nd4j { | ||||||
|  | namespace ops { | ||||||
|  | namespace helpers { | ||||||
|  |     void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens); | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | } | ||||||
|  | } | ||||||
|  | #endif | ||||||
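fillUpSegments is only declared in this header; its CUDA implementation lives elsewhere. As a rough mental model (an assumption based on the parameter names classesRangesBegs/classesRangesLens, not the actual kernel), it computes, for every class id, the start offset and element count of that segment:

```cpp
// Hedged CPU reference for what fillUpSegments presumably produces: per class id, the start
// offset and the number of elements of that segment, assuming the segment ids in `indices`
// are non-decreasing as for the sorted segment_* ops.
#include <cstdint>
#include <vector>

void fillUpSegmentsSketch(const std::vector<int64_t>& indices, int64_t numClasses,
                          std::vector<int64_t>& classesRangesBegs,
                          std::vector<int64_t>& classesRangesLens) {
    classesRangesBegs.assign(numClasses, 0);
    classesRangesLens.assign(numClasses, 0);
    for (int64_t idx : indices)
        ++classesRangesLens[idx];                                    // count members per class
    for (int64_t c = 1; c < numClasses; ++c)                         // exclusive prefix sum
        classesRangesBegs[c] = classesRangesBegs[c - 1] + classesRangesLens[c - 1];
}
```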
| @ -23,6 +23,7 @@ | |||||||
| 
 | 
 | ||||||
| #include <ops/declarable/helpers/helpers.h> | #include <ops/declarable/helpers/helpers.h> | ||||||
| #include <helpers/helper_random.h> | #include <helpers/helper_random.h> | ||||||
|  | #include <graph/RandomGenerator.h> | ||||||
| 
 | 
 | ||||||
| namespace nd4j    { | namespace nd4j    { | ||||||
| namespace ops     { | namespace ops     { | ||||||
| @ -32,7 +33,7 @@ namespace helpers { | |||||||
| 
 | 
 | ||||||
| 	void trace(nd4j::LaunchContext * context, const NDArray& input, NDArray& output); | 	void trace(nd4j::LaunchContext * context, const NDArray& input, NDArray& output); | ||||||
| 
 | 
 | ||||||
| 	void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace); | 	void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace); | ||||||
| 
 | 
 | ||||||
|     // auxiliary function which serves a recursive purpose and is used in the pad operation
 |     // auxiliary function which serves a recursive purpose and is used in the pad operation
 | ||||||
| 	// void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue);
 | 	// void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue);
 | ||||||
|  | |||||||
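The randomShuffle helper now takes nd4j::graph::RandomGenerator instead of random::RandomBuffer. For orientation, a generator-agnostic Fisher-Yates shuffle is sketched below; the names are illustrative and this is not the helper's kernel code, only the contract it has to honour (a uniform permutation, optionally in place):

```cpp
// Generic Fisher-Yates shuffle driven by a 64-bit RNG callable.
#include <cstdint>
#include <utility>
#include <vector>

template <typename T, typename Rng>              // Rng: callable returning uint64_t
void fisherYatesShuffle(std::vector<T>& data, Rng&& nextU64) {
    for (int64_t i = static_cast<int64_t>(data.size()) - 1; i > 0; --i) {
        const int64_t j = static_cast<int64_t>(nextU64() % static_cast<uint64_t>(i + 1));  // j in [0, i]
        std::swap(data[i], data[j]);
    }
}
```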
| @ -1126,15 +1126,7 @@ inline __device__ bool nd4j_atomicAdd<bool>(bool* address, bool val)  { | |||||||
| 
 | 
 | ||||||
| template <> | template <> | ||||||
| inline __device__ double nd4j_atomicSub<double>(double* address, double val)  { | inline __device__ double nd4j_atomicSub<double>(double* address, double val)  { | ||||||
| 	unsigned long long int* address_as_ull = | 	return nd4j_atomicAdd<double>(address, -val); | ||||||
| 			(unsigned long long int *) address; |  | ||||||
| 	unsigned long long int old = *address_as_ull, assumed; |  | ||||||
| 	do { |  | ||||||
| 		assumed = old; |  | ||||||
| 		old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val - |  | ||||||
| 				__longlong_as_double(assumed))); |  | ||||||
| 	} while (assumed != old); |  | ||||||
| 	return __longlong_as_double(old); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <> | template <> | ||||||
| @ -1152,15 +1144,7 @@ inline __device__ double nd4j_atomicMul<double>(double* address, double val)  { | |||||||
| 
 | 
 | ||||||
| template <> | template <> | ||||||
| inline __device__ double nd4j_atomicDiv<double>(double* address, double val)  { | inline __device__ double nd4j_atomicDiv<double>(double* address, double val)  { | ||||||
| 	unsigned long long int* address_as_ull = | 	return nd4j_atomicMul<double>(address, 1./val); | ||||||
| 			(unsigned long long int*) address; |  | ||||||
| 	unsigned long long int old = *address_as_ull, assumed; |  | ||||||
| 	do { |  | ||||||
| 		assumed = old; |  | ||||||
| 		old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val / |  | ||||||
| 				__longlong_as_double(assumed))); |  | ||||||
| 	} while (assumed != old); |  | ||||||
| 	return __longlong_as_double(old); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <> | template <> | ||||||
| @ -1179,14 +1163,16 @@ inline __device__ int32_t nd4j_atomicAdd<int32_t>(int32_t* address, int32_t val) | |||||||
| 
 | 
 | ||||||
| template <> | template <> | ||||||
| inline __device__ float nd4j_atomicSub<float>(float* address, float val) { | inline __device__ float nd4j_atomicSub<float>(float* address, float val) { | ||||||
| 	int* address_as_ull = (int*) address; | 	return nd4j_atomicAdd<float>(address, -val); | ||||||
| 	int old = *address_as_ull, assumed; | } | ||||||
| 	do { | 
 | ||||||
| 		assumed = old; | template <> | ||||||
| 		old = atomicCAS(address_as_ull, assumed, __float_as_int(val - | inline __device__ float16 nd4j_atomicSub<float16>(float16* address, float16 val) { | ||||||
| 				__float_as_int(assumed))); | 	return nd4j_atomicAdd<float16>(address, -val); | ||||||
| 	} while (assumed != old); | } | ||||||
| 	return __int_as_float(old); | template <> | ||||||
|  | inline __device__ bfloat16 nd4j_atomicSub<bfloat16>(bfloat16* address, bfloat16 val) { | ||||||
|  | 	return nd4j_atomicAdd<bfloat16>(address, -val); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <> | template <> | ||||||
| @ -1415,6 +1401,30 @@ inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val) | |||||||
| 
 | 
 | ||||||
| template <> | template <> | ||||||
| inline __device__ float nd4j_atomicDiv<float>(float* address, float val) { | inline __device__ float nd4j_atomicDiv<float>(float* address, float val) { | ||||||
|  | 	int* address_as_ull = | ||||||
|  | 			(int*)address; | ||||||
|  | 	int old = *address_as_ull, assumed; | ||||||
|  | 	do { | ||||||
|  | 		assumed = old; | ||||||
|  | 		old = atomicCAS(address_as_ull, assumed, __float_as_int(__int_as_float(assumed) / val )); | ||||||
|  | 	} while (assumed != old); | ||||||
|  | 	return __int_as_float(old); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | template <> | ||||||
|  | inline __device__ float16 nd4j_atomicDiv<float16>(float16* address, float16 val) { | ||||||
|  | 	int* address_as_ull = | ||||||
|  | 			(int*)address; | ||||||
|  | 	int old = *address_as_ull, assumed; | ||||||
|  | 	do { | ||||||
|  | 		assumed = old; | ||||||
 |  | 		old = atomicCAS(address_as_ull, assumed, __float_as_int(__int_as_float(assumed) / | ||||||
 |  | 				(float) val)); | ||||||
|  | 	} while (assumed != old); | ||||||
|  | 	return __int_as_float(old); | ||||||
|  | } | ||||||
|  | template <> | ||||||
|  | inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 val) { | ||||||
| 	int* address_as_ull = | 	int* address_as_ull = | ||||||
| 			(int*)address; | 			(int*)address; | ||||||
| 	int old = *address_as_ull, assumed; | 	int old = *address_as_ull, assumed; | ||||||
|  | |||||||
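Two things happen in this file: atomicSub/atomicDiv for double, float, float16 and bfloat16 are rewritten as atomicAdd of the negation (and, for double, atomicMul of the reciprocal), while the remaining specializations keep the usual compare-and-swap spin loop. The sketch below shows that CAS pattern in isolation for a float multiply; it is illustrative CUDA, not one of the functions above:

```cpp
// Spin on atomicCAS over the value's bit pattern: attempt the exchange, retry if another
// thread won the race. This is the pattern behind the non-native atomics in this file.
__device__ inline float atomicMulFloatSketch(float* address, float val) {
    int* addressAsInt = reinterpret_cast<int*>(address);
    int old = *addressAsInt;
    int assumed;
    do {
        assumed = old;
        old = atomicCAS(addressAsInt, assumed,
                        __float_as_int(__int_as_float(assumed) * val));
    } while (assumed != old);
    return __int_as_float(old);   // value before the update, like the built-in atomics
}
```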
| @ -76,6 +76,9 @@ | |||||||
|         (nd4j::DataType::FLOAT32, float), \ |         (nd4j::DataType::FLOAT32, float), \ | ||||||
|         (nd4j::DataType::DOUBLE, double) |         (nd4j::DataType::DOUBLE, double) | ||||||
| 
 | 
 | ||||||
|  | #define FLOAT_NATIVE \ | ||||||
|  |         (nd4j::DataType::FLOAT32, float), \ | ||||||
|  |         (nd4j::DataType::DOUBLE, double) | ||||||
| 
 | 
 | ||||||
| #define FLOAT_TYPES_0 \ | #define FLOAT_TYPES_0 \ | ||||||
|         (nd4j::DataType::HALF, float16) |         (nd4j::DataType::HALF, float16) | ||||||
|  | |||||||
| @ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { | |||||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
| 
 | 
 | ||||||
|     NDArray* result = results->at(0); |     NDArray* result = results->at(0); | ||||||
|     result->printIndexedBuffer("OOOOUUUUTTT"); | //    result->printIndexedBuffer("OOOOUUUUTTT");
 | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(expected.isSameShapeStrict(result)); |     ASSERT_TRUE(expected.isSameShapeStrict(result)); | ||||||
|     ASSERT_TRUE(expected.equalsTo(result)); |     ASSERT_TRUE(expected.equalsTo(result)); | ||||||
| @ -1881,9 +1881,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { | |||||||
| ////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { | TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { | ||||||
| 
 | 
 | ||||||
|     NDArray boxes    = NDArrayFactory::create<float>('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f, |     NDArray boxes    = NDArrayFactory::create<double>('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f, | ||||||
|                                          0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101}); |                                          0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101}); | ||||||
|     NDArray scales = NDArrayFactory::create<float>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5
 |     NDArray scales = NDArrayFactory::create<double>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5
 | ||||||
|     NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5}); |     NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5}); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::non_max_suppression op; |     nd4j::ops::non_max_suppression op; | ||||||
| @ -1892,7 +1892,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { | |||||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
| 
 | 
 | ||||||
|     NDArray* result = results->at(0); |     NDArray* result = results->at(0); | ||||||
|     result->printBuffer("NonMaxSuppression OUtput2"); | //    result->printBuffer("NonMaxSuppression OUtput2");
 | ||||||
|     ASSERT_TRUE(expected.isSameShapeStrict(result)); |     ASSERT_TRUE(expected.isSameShapeStrict(result)); | ||||||
|     ASSERT_TRUE(expected.equalsTo(result)); |     ASSERT_TRUE(expected.equalsTo(result)); | ||||||
| 
 | 
 | ||||||
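The expected index lists in the Image_NonMaxSuppressing tests follow the usual greedy rule: keep the highest-scoring box, suppress everything overlapping it beyond the IoU threshold, repeat up to the output limit. A plain C++ reference of that rule (illustrative only; the op's helper is CUDA-based and its exact threshold handling is not visible in this hunk):

```cpp
// Greedy non-max-suppression reference.
#include <algorithm>
#include <numeric>
#include <vector>

struct Box { float y1, x1, y2, x2; };

static float iou(const Box& a, const Box& b) {
    const float areaA = (a.y2 - a.y1) * (a.x2 - a.x1);
    const float areaB = (b.y2 - b.y1) * (b.x2 - b.x1);
    const float ih = std::max(0.f, std::min(a.y2, b.y2) - std::max(a.y1, b.y1));
    const float iw = std::max(0.f, std::min(a.x2, b.x2) - std::max(a.x1, b.x1));
    const float inter = ih * iw;
    return inter <= 0.f ? 0.f : inter / (areaA + areaB - inter);
}

static std::vector<int> nonMaxSuppressionSketch(const std::vector<Box>& boxes,
                                                const std::vector<float>& scores,
                                                int maxOutputSize, float iouThreshold) {
    std::vector<int> order(boxes.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](int a, int b) { return scores[a] > scores[b]; });

    std::vector<int> keep;
    std::vector<bool> suppressed(boxes.size(), false);
    for (int i : order) {
        if (static_cast<int>(keep.size()) >= maxOutputSize) break;
        if (suppressed[i]) continue;
        keep.push_back(i);
        for (int j : order)
            if (!suppressed[j] && j != i && iou(boxes[i], boxes[j]) > iouThreshold)
                suppressed[j] = true;
    }
    return keep;
}
```

With the six boxes and scores of Image_NonMaxSuppressing_2, an assumed 0.5 IoU threshold and an output limit of 3, this yields {3, 0, 5}, matching the expected array.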
| @ -1970,6 +1970,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { | |||||||
| 
 | 
 | ||||||
|     delete results; |     delete results; | ||||||
| } | } | ||||||
|  | 
 | ||||||
| ////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { | TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -421,3 +421,200 @@ ASSERT_TRUE(result->at(0)->e<bool>(0)); | |||||||
|     //ASSERT_TRUE(exp.equalsTo(result->at(0)));
 |     //ASSERT_TRUE(exp.equalsTo(result->at(0)));
 | ||||||
|     delete result; |     delete result; | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustHue_1) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {2,2,3}, {0,100,56, 17,220,5,  150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97,  2,255,244}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_hue op; | ||||||
|  |     auto results = op.execute({&input}, {0.5}, {2}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  |     // result->printIndexedBuffer();
 | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustHue_2) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {2,2,3}, {0,100,56, 17,220,5,  150,97,230,   255,2,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {2,2,3}, {4,100,0,  146,220,5, 97,123.8,230, 255,2,164.8}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_hue op; | ||||||
|  |     auto results = op.execute({&input}, {0.9}, {2}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustHue_3) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {2,2,3}, {0,100,56,    17,220,5,          150,97,230,     255,2,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001,  229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_hue op; | ||||||
|  |     auto results = op.execute({&input}, {-0.9}, {2}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustHue_4) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {2,3,2}, {0,17,   100,220, 56,5,   150,255, 97,2,   230,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {2,3,2}, {100,208, 0,5,   44,220,  177,2,   230,255, 97,244}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_hue op; | ||||||
|  |     auto results = op.execute({&input}, {0.5}, {1}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustHue_5) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {3,2,2}, {0,17, 150,255,   100,220, 97,2,  56,5, 230,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {3,2,2}, {100,208, 177,2,  0,5, 230,255,   44,220, 97,244}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_hue op; | ||||||
|  |     auto results = op.execute({&input}, {0.5}, {0}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
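The adjustHue_* expectations can be reproduced per pixel under the assumption that the op converts RGB to HSV, adds the delta (a fraction of the full hue circle) to H modulo 1, and converts back. The sketch below is a standalone scalar version of that rule, not the NDArray-based helper; feeding it (0, 100, 56) with delta 0.5 gives (100, 0, 44), matching adjustHue_1.

```cpp
// Per-pixel hue shift: RGB -> HSV, H += delta (mod 1), HSV -> RGB. Illustrative only.
#include <algorithm>
#include <cmath>

struct Rgb { float r, g, b; };

Rgb adjustHuePixel(Rgb in, float delta) {
    const float mx = std::max({in.r, in.g, in.b});
    const float mn = std::min({in.r, in.g, in.b});
    const float c  = mx - mn;

    float h = 0.f;                                       // hue in [0,1)
    if (c > 0.f) {
        if (mx == in.r)      h = std::fmod((in.g - in.b) / c, 6.f);
        else if (mx == in.g) h = (in.b - in.r) / c + 2.f;
        else                 h = (in.r - in.g) / c + 4.f;
        h /= 6.f;
        if (h < 0.f) h += 1.f;
    }

    h = std::fmod(h + delta, 1.f);                       // shift the hue
    if (h < 0.f) h += 1.f;

    const float hh = h * 6.f;                            // back to RGB
    const float x  = c * (1.f - std::fabs(std::fmod(hh, 2.f) - 1.f));
    float r = 0.f, g = 0.f, b = 0.f;
    if      (hh < 1.f) { r = c; g = x; }
    else if (hh < 2.f) { r = x; g = c; }
    else if (hh < 3.f) { g = c; b = x; }
    else if (hh < 4.f) { g = x; b = c; }
    else if (hh < 5.f) { r = x; b = c; }
    else               { r = c; b = x; }
    const float m = mx - c;
    return { r + m, g + m, b + m };
}
```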
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustSaturation_1) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {2,2,3}, {0,100,56,  17,220,5,         150,97,230,    255,2,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {2,2,3}, {50,100,78, 118.5,220,112.5,  190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_saturation op; | ||||||
|  |     auto results = op.execute({&input}, {0.5}, {2}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  |     // result->printIndexedBuffer();
 | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustSaturation_2) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {2,2,3}, {0,100,56,    17,220,5,          150,97,230,        255,2,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_saturation op; | ||||||
|  |     auto results = op.execute({&input}, {10}, {2}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  |     // result->printIndexedBuffer();
 | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustSaturation_3) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {2,2,3}, {0,100,56,       17,220,5,       150,97,230,     255,2,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_saturation op; | ||||||
|  |     auto results = op.execute({&input}, {-10}, {2}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustSaturation_4) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {2,3,2}, {0,17,   100,220,  56,5,   150,255,  97,2,   230,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5,  190,255, 163.5,128.5, 230,134}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_saturation op; | ||||||
|  |     auto results = op.execute({&input}, {0.5}, {1}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  |     // result->printIndexedBuffer();
 | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests13, adjustSaturation_5) { | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {3,2,2}, {0,17,     150,255,  100,220,  97,2,        56,5,     230,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray exp  ('c', {3,2,2}, {50,118.5, 190,255,  100,220,  163.5,128.5, 78,112.5, 230,134}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::adjust_saturation op; | ||||||
|  |     auto results = op.execute({&input}, {0.5}, {0}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto result = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(result)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(result)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | |||||||
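Similarly, the adjustSaturation_* expectations follow from scaling S in HSV space and clamping it to [0,1]; with factor -10 the clamp drives S to 0, which is why adjustSaturation_3 expects every pixel to collapse to its max channel value. A standalone scalar sketch under that assumption (not the helper itself):

```cpp
// Per-pixel saturation scaling: RGB -> HSV, S = clamp(S * factor, 0, 1), HSV -> RGB.
#include <algorithm>
#include <cmath>

struct RgbPix { float r, g, b; };

RgbPix adjustSaturationPixel(RgbPix in, float factor) {
    const float v  = std::max({in.r, in.g, in.b});
    const float mn = std::min({in.r, in.g, in.b});
    float c = v - mn;

    float h = 0.f;                                       // hue in [0,6)
    if (c > 0.f) {
        if (v == in.r)      h = std::fmod((in.g - in.b) / c, 6.f);
        else if (v == in.g) h = (in.b - in.r) / c + 2.f;
        else                h = (in.r - in.g) / c + 4.f;
        if (h < 0.f) h += 6.f;
    }

    float s = v > 0.f ? c / v : 0.f;
    s = std::min(1.f, std::max(0.f, s * factor));        // scale and clamp the saturation
    c = v * s;

    const float x = c * (1.f - std::fabs(std::fmod(h, 2.f) - 1.f));
    float r = 0.f, g = 0.f, b = 0.f;
    if      (h < 1.f) { r = c; g = x; }
    else if (h < 2.f) { r = x; g = c; }
    else if (h < 3.f) { g = c; b = x; }
    else if (h < 4.f) { g = x; b = c; }
    else if (h < 5.f) { r = x; b = c; }
    else              { r = c; b = x; }
    const float m = v - c;
    return { r + m, g + m, b + m };
}
```

With factor 0.5 and pixel (0, 100, 56) this gives (50, 100, 78), matching the first pixel of adjustSaturation_1.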
| @ -1479,6 +1479,27 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) { | |||||||
| 
 | 
 | ||||||
|     delete results; |     delete results; | ||||||
| } | } | ||||||
|  | //////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests5, random_shuffle_test04) { | ||||||
|  |     auto input = NDArrayFactory::create<double>('c', {4}); | ||||||
|  |     input.linspace(1); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::random_shuffle op; | ||||||
|  |     //NDArray* output;
 | ||||||
|  |     auto results = op.execute({&input}, {},  {},  {}, true, nd4j::DataType::DOUBLE); | ||||||
|  |     ASSERT_EQ(Status::OK(), results->status()); | ||||||
|  |     auto output = &input; //results->at(0);
 | ||||||
|  |     bool haveZeros = false; | ||||||
|  |     for(int i = 0; i < output->lengthOf(); ++i) | ||||||
|  |         if(output->e<float>(i) == (float)0.) | ||||||
|  |             haveZeros = true; | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(input.isSameShape(output)); | ||||||
|  |     //ASSERT_TRUE(!input.equalsTo(output));
 | ||||||
|  |     ASSERT_TRUE(!haveZeros); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests5, random_shuffle_test4) { | TEST_F(DeclarableOpsTests5, random_shuffle_test4) { | ||||||
| @ -1486,17 +1507,17 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) { | |||||||
|     input.linspace(1); |     input.linspace(1); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::random_shuffle op; |     nd4j::ops::random_shuffle op; | ||||||
|  |     //NDArray* output;
 | ||||||
|     auto results = op.execute({&input}, {},  {},  {}, false, nd4j::DataType::DOUBLE); |     auto results = op.execute({&input}, {},  {},  {}, false, nd4j::DataType::DOUBLE); | ||||||
|  |     ASSERT_EQ(Status::OK(), results->status()); | ||||||
|     auto output = results->at(0); |     auto output = results->at(0); | ||||||
| 
 |  | ||||||
|     bool haveZeros = false; |     bool haveZeros = false; | ||||||
|     for(int i = 0; i < output->lengthOf(); ++i) |     for(int i = 0; i < output->lengthOf(); ++i) | ||||||
|         if(output->e<float>(i) == (float)0.) |         if(output->e<float>(i) == (float)0.) | ||||||
|             haveZeros = true; |             haveZeros = true; | ||||||
| 
 | 
 | ||||||
|     ASSERT_EQ(Status::OK(), results->status()); |  | ||||||
|     ASSERT_TRUE(input.isSameShape(output)); |     ASSERT_TRUE(input.isSameShape(output)); | ||||||
|     ASSERT_TRUE(!input.equalsTo(output)); |     //ASSERT_TRUE(!input.equalsTo(output));
 | ||||||
|     ASSERT_TRUE(!haveZeros); |     ASSERT_TRUE(!haveZeros); | ||||||
| 
 | 
 | ||||||
|     delete results; |     delete results; | ||||||
|  | |||||||
| @ -1601,8 +1601,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) { | |||||||
|     ASSERT_EQ(ND4J_STATUS_OK, result->status()); |     ASSERT_EQ(ND4J_STATUS_OK, result->status()); | ||||||
| 
 | 
 | ||||||
|     auto z = result->at(0); |     auto z = result->at(0); | ||||||
| //    z->printIndexedBuffer("Output ");
 |     z->printIndexedBuffer("Output "); | ||||||
| //    exp.printIndexedBuffer("Expected ");
 |     exp.printIndexedBuffer("Expected "); | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(exp.isSameShape(z)); |     ASSERT_TRUE(exp.isSameShape(z)); | ||||||
|     ASSERT_TRUE(exp.equalsTo(z)); |     ASSERT_TRUE(exp.equalsTo(z)); | ||||||
| @ -1610,6 +1610,75 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) { | |||||||
|     delete result; |     delete result; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests6, MatrixInverse_01) { | ||||||
|  | 
 | ||||||
|  |     auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, { | ||||||
|  |             2.,  4., 60.,  8., 10., | ||||||
|  |             0.,  1.,  2.,  3.,  4., | ||||||
|  |             0.,  0.,  2.,  4.,  6., | ||||||
|  |             0.,  0.,  0.,  1.,  2., | ||||||
|  |             0.,  0.,  0.,  0.,  4. | ||||||
|  | 
 | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, { | ||||||
|  |             0.5, -2.0, -13.0, 54.0, -6.75, | ||||||
|  |             0.0,  1.0,  -1.0,  1.0,   0.0, | ||||||
|  |             0,    0,   0.5, -2.0,  0.25, | ||||||
|  |             0,    0,     0,  1.0,  -0.5, | ||||||
|  |             0,    0,     0,    0,  0.25 | ||||||
|  | 
 | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::matrix_inverse op; | ||||||
|  |     auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, result->status()); | ||||||
|  | 
 | ||||||
|  |     auto z = result->at(0); | ||||||
|  |     z->printIndexedBuffer("Output "); | ||||||
|  |     exp.printIndexedBuffer("Expected "); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(z)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(z)); | ||||||
|  | 
 | ||||||
|  |     delete result; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests6, MatrixInverse_02) { | ||||||
|  | 
 | ||||||
|  |     auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, { | ||||||
|  |             1.,  0.,  0.,  0.,  0., | ||||||
|  |             2.,  1.,  0.,  0.,  0., | ||||||
|  |             30.,  2.,  1.,  0.,  0., | ||||||
|  |             4.,  3.,  2.,  1.,  0., | ||||||
|  |             5.,  4.,  3.,  2.,  1. | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, { | ||||||
|  |             1.0,  0.0,  0.0,  0.0, 0., | ||||||
|  |             -2.0,  1.0,   0.,   0., 0., | ||||||
|  |             -26.0, -2.0,    1,    0, 0., | ||||||
|  |             54.0,  1.0, -2.0,    1, 0., | ||||||
|  |             -27.0,  0.0,  1.0, -2.0, 1. | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::matrix_inverse op; | ||||||
|  |     auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, result->status()); | ||||||
|  | 
 | ||||||
|  |     auto z = result->at(0); | ||||||
|  |     z->printIndexedBuffer("Output "); | ||||||
|  |     exp.printIndexedBuffer("Expected "); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(z)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(z)); | ||||||
|  | 
 | ||||||
|  |     delete result; | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
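The new MatrixInverse_01/_02 cases (and _03 below) use triangular inputs, so their expected outputs can be checked by hand with substitution rather than a full decomposition. A small reference for the lower-triangular case (plain C++, not the op's implementation); applied to the matrix in MatrixInverse_02 it reproduces the expected rows, e.g. row 2 = {-26, -2, 1, 0, 0}.

```cpp
// Forward-substitution inverse of a lower-triangular matrix, for hand-checking expectations.
#include <vector>

using Matrix = std::vector<std::vector<double>>;

Matrix invertLowerTriangular(const Matrix& L) {
    const int n = static_cast<int>(L.size());
    Matrix inv(n, std::vector<double>(n, 0.0));
    for (int col = 0; col < n; ++col) {
        inv[col][col] = 1.0 / L[col][col];
        for (int row = col + 1; row < n; ++row) {
            double sum = 0.0;
            for (int k = col; k < row; ++k)
                sum += L[row][k] * inv[k][col];
            inv[row][col] = -sum / L[row][row];   // solve L * inv = I column by column
        }
    }
    return inv;
}
```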
| ////////////////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
| /*
 | /*
 | ||||||
| @ -1658,6 +1727,39 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) { | |||||||
|     delete result; |     delete result; | ||||||
| } | } | ||||||
| */ | */ | ||||||
|  | TEST_F(DeclarableOpsTests6, MatrixInverse_03) { | ||||||
|  | 
 | ||||||
|  |     auto x = NDArrayFactory::create<float>('c', {5, 5}, { | ||||||
|  |             4.,   0.,  0.,  0.,  0., | ||||||
|  |             4.,   2.,  0.,  0.,  0., | ||||||
|  |            30.,   2.,  1.,  0.,  0., | ||||||
|  |             8.,   6.,  4.,  2.,  0., | ||||||
|  |            15.,  12.,  9.,  6.,  3., | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     auto exp = NDArrayFactory::create<float>('c', {5, 5}, { | ||||||
|  |             0.25,  0.0,    0.0,   0.0,   0.0, | ||||||
|  |             -0.50,  0.5,    0.0,   0.0,   0.0, | ||||||
|  |             -6.50, -1.0,    1.0,   0.0,   0.0, | ||||||
|  |             13.50,  0.5,   -2.0,   0.5,   0.0, | ||||||
|  |             -6.75,  0.0,    1.0,  -1.0,   0.33333333 | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::matrix_inverse op; | ||||||
|  |     auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, result->status()); | ||||||
|  | 
 | ||||||
|  |     auto z = result->at(0); | ||||||
|  |     z->printIndexedBuffer("Output "); | ||||||
|  |     exp.printIndexedBuffer("Expected "); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(z)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(z)); | ||||||
|  | 
 | ||||||
|  |     delete result; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| ////////////////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests6, MatrixInverse_3) { | TEST_F(DeclarableOpsTests6, MatrixInverse_3) { | ||||||
| 
 | 
 | ||||||
| @ -1695,7 +1797,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) { | |||||||
| ////////////////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests6, MatrixInverse_4) { | TEST_F(DeclarableOpsTests6, MatrixInverse_4) { | ||||||
| 
 | 
 | ||||||
|     auto x = NDArrayFactory::create<double>('c', {5, 5}, { |     auto x = NDArrayFactory::create<float>('c', {5, 5}, { | ||||||
|                     1.,  2., 30.,  4.,  5., |                     1.,  2., 30.,  4.,  5., | ||||||
|                     0.,  1.,  2.,  3.,  4., |                     0.,  1.,  2.,  3.,  4., | ||||||
|                     0.,  0.,  1.,  2.,  3., |                     0.,  0.,  1.,  2.,  3., | ||||||
| @ -1703,7 +1805,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) { | |||||||
|                     0.,  0.,  0.,  0.,  1. |                     0.,  0.,  0.,  0.,  1. | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     auto exp = NDArrayFactory::create<double>('c', {5, 5}, { |     auto exp = NDArrayFactory::create<float>('c', {5, 5}, { | ||||||
|      1.0,  -2.0,  -26.0,  54.0, -27.0, |      1.0,  -2.0,  -26.0,  54.0, -27.0, | ||||||
|      0.0,   1.0,  -2.0,    1.0,   0.0, |      0.0,   1.0,  -2.0,    1.0,   0.0, | ||||||
|      0.0,   0.0,   1.0,   -2.0,   1.0, |      0.0,   0.0,   1.0,   -2.0,   1.0, | ||||||
| @ -1712,13 +1814,13 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) { | |||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::matrix_inverse op; |     nd4j::ops::matrix_inverse op; | ||||||
|     auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE); |     auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); | ||||||
| 
 | 
 | ||||||
|     ASSERT_EQ(ND4J_STATUS_OK, result->status()); |     ASSERT_EQ(ND4J_STATUS_OK, result->status()); | ||||||
| 
 | 
 | ||||||
|     auto z = result->at(0); |     auto z = result->at(0); | ||||||
|     //z->printIndexedBuffer("Output ");
 |     z->printIndexedBuffer("Output "); | ||||||
|     //exp.printIndexedBuffer("Expected ");
 |     exp.printIndexedBuffer("Expected "); | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(exp.isSameShape(z)); |     ASSERT_TRUE(exp.isSameShape(z)); | ||||||
|     ASSERT_TRUE(exp.equalsTo(z)); |     ASSERT_TRUE(exp.equalsTo(z)); | ||||||
|  | |||||||
| @ -763,15 +763,15 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { | TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { | ||||||
|     auto input = NDArrayFactory::create<double>('c', {4, 4},   {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f}); |     auto input = NDArrayFactory::create<int>('c', {4, 4},   {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); | ||||||
|     auto exp = NDArrayFactory::create<double>('c', {4, 4, 16}, {1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,  0.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f,1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,  0.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f, |     auto exp = NDArrayFactory::create<bool>('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,1, 1, 0, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0, | ||||||
|                                         1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,  0.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f,1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f,  0.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f, |                                         1, 1, 1, 0, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0, | ||||||
|                                         1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f,  0.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f,  0.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f, |                                         1, 1, 1, 1, 1, 0, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0,  0,  0,  0,  0,  0,  0, 0, | ||||||
|                                         1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f,  0.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f,  0.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f, |                                         1, 1, 1, 1, 1, 1, 1, 0, 0,  0,  0,  0,  0,  0,  0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0,  0,  0,  0,  0,  0,  0, 0, | ||||||
|                                         1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,  0.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,  1.f,  0.f,  0.f,  0.f,  0.f,  0.f, 0.f, |                                         1, 1, 1, 1, 1, 1, 1, 1, 1,  0,  0,  0,  0,  0,  0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  0,  0,  0,  0,  0, 0, | ||||||
|                                         1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,  1.f,  1.f,  0.f,  0.f,  0.f,  0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,  1.f,  1.f,  1.f,  0.f,  0.f,  0.f, 0.f, |                                         1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  0,  0,  0,  0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  0,  0,  0, 0, | ||||||
|                                         1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,  1.f,  1.f,  1.f,  1.f,  0.f,  0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,  1.f,  1.f,  1.f,  1.f,  1.f,  0.f, 0.f, |                                         1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  0,  0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  1,  0, 0, | ||||||
|                                         1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,  1.f,  1.f,  1.f,  1.f,  1.f,  1.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f,  1.f,  1.f,  1.f,  1.f,  1.f,  1.f, 1.f }); |                                         1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  1,  1, 0,1, 1, 1, 1, 1, 1, 1, 1, 1,  1,  1,  1,  1,  1,  1, 1 }); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::sequence_mask op; |     nd4j::ops::sequence_mask op; | ||||||
|     auto result = op.execute({&input}, {}, {}); |     auto result = op.execute({&input}, {}, {}); | ||||||
| @ -788,19 +788,19 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { | TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { | ||||||
|     auto input = NDArrayFactory::create<double>('c', {2, 2, 2},   {10., 20., 30., 4., 0., 6., 7., 8.}); |     auto input = NDArrayFactory::create<int>('c', {2, 2, 2},   {10, 20, 30, 4, 0, 6, 7, 8}); | ||||||
|     auto exp = NDArrayFactory::create<double>('c', {2, 2, 2, 30}, {   1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., |     auto exp = NDArrayFactory::create<bool>('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | ||||||
|     1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,    1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., |              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | ||||||
|     1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,    0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., |              1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | ||||||
|     1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,    1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., |              1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | ||||||
|     1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); |              1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::sequence_mask op; |     nd4j::ops::sequence_mask op; | ||||||
|     auto result = op.execute({&input}, {}, {}); |     auto result = op.execute({&input}, {}, {}); | ||||||
|     ASSERT_EQ(Status::OK(), result->status()); |     ASSERT_EQ(Status::OK(), result->status()); | ||||||
| 
 | 
 | ||||||
|     auto z = result->at(0); |     auto z = result->at(0); | ||||||
| //    z->printIndexedBuffer("Output");
 | //    z->printBuffer("Output");
 | ||||||
| //    z->printShapeInfo("Shape");
 | //    z->printShapeInfo("Shape");
 | ||||||
|     ASSERT_TRUE(exp.isSameShape(z)); |     ASSERT_TRUE(exp.isSameShape(z)); | ||||||
|     ASSERT_TRUE(exp.equalsTo(z)); |     ASSERT_TRUE(exp.equalsTo(z)); | ||||||
|  | |||||||
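The rewritten SequenceMask expectations encode the op's contract: the output is boolean, with one extra trailing dimension of size maxLen (here the maximum value found in the input), and mask[..., j] is true exactly when j < length. A minimal sketch of that rule for a 1-D lengths vector, with maxLen passed explicitly:

```cpp
// sequence_mask rule: mask[i][j] == (j < lengths[i]).
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<std::vector<bool>> sequenceMaskSketch(const std::vector<int64_t>& lengths, int64_t maxLen) {
    std::vector<std::vector<bool>> mask(lengths.size(), std::vector<bool>(maxLen, false));
    for (std::size_t i = 0; i < lengths.size(); ++i)
        for (int64_t j = 0; j < maxLen && j < lengths[i]; ++j)
            mask[i][j] = true;
    return mask;
}
```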
| @ -2770,9 +2770,8 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) { | |||||||
|     ASSERT_TRUE(isGradCorrect); |     ASSERT_TRUE(isGradCorrect); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////////////////////
 |  | ||||||
| /*
 | /*
 | ||||||
| //2019/02/23 AB - GRU backprop tests disabled pending update of GRU backprop op after rewriting forward pass
 | ////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { | TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { | ||||||
| 
 | 
 | ||||||
|     const int bS = 2; |     const int bS = 2; | ||||||
| @ -2780,160 +2779,58 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { | |||||||
|     const int nU = 4; |     const int nU = 4; | ||||||
| 
 | 
 | ||||||
|     NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE); |     NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE); | ||||||
|     NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE); |     NDArray hi('c', {bS, nU}, nd4j::DataType::DOUBLE); | ||||||
|     NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE); |     NDArray W('c', {iS+nU, 2*nU}, nd4j::DataType::DOUBLE); | ||||||
|     NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE); |     NDArray Wc('c', {iS+nU, nU}, nd4j::DataType::DOUBLE); | ||||||
|     NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE); |     NDArray b('c', {2*nU}, nd4j::DataType::DOUBLE); | ||||||
|  |     NDArray bc('c', {nU}, nd4j::DataType::DOUBLE); | ||||||
|  |     NDArray dLdr('c', {bS, nU}, nd4j::DataType::DOUBLE); | ||||||
|  |     NDArray dLdu('c', {bS, nU}, nd4j::DataType::DOUBLE); | ||||||
|  |     NDArray dLdc('c', {bS, nU}, nd4j::DataType::DOUBLE); | ||||||
|     NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE); |     NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE); | ||||||
| 
 | 
 | ||||||
|     x.linspace(0.5, 0.5); |     x.linspace(-5, 0.5); | ||||||
|     h0 = 1.; |     hi   = 1.; | ||||||
|     Wx = 0.003; |     W    = 0.003; | ||||||
|     Wh = 0.006; |     Wc   = 0.006; | ||||||
|     b    = 0.5; |     b    = 0.5; | ||||||
|  |     bc   = 0.35; | ||||||
| 
 | 
 | ||||||
|     const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {}); | 
 | ||||||
|     const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {}); |     const OpArgsHolder argsHolderFF({&x, &hi, &W, &Wc, &b, &bc}, {}, {}); | ||||||
|  |     nd4j::ops::gruCell op; | ||||||
|  |     auto results = op.execute(argsHolderFF); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto u = results->at(1);    // [bS, nU]
 | ||||||
|  |     auto c = results->at(2);    // [bS, nU]
 | ||||||
|  |     auto h = results->at(3);    // [bS, nU]
 | ||||||
|  | 
 | ||||||
|  |     dLdh = 1.; // SUM loss
 | ||||||
|  | 
 | ||||||
|  |     NDArray Wch = Wc({iS,iS+nU, 0,0}); // [nU, nU]
 | ||||||
|  |     NDArray dhdc  = 1. - *u; | ||||||
|  |     NDArray dhdu  = hi - *c; | ||||||
|  |     NDArray dcdZc = 1. - *c * *c; | ||||||
|  |     dLdc.assign(dLdh * dhdc); | ||||||
|  |     dLdu.assign(dLdh * dhdu); | ||||||
|  |     dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose())); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {}); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::gruCell opFF; |     nd4j::ops::gruCell opFF; | ||||||
|     nd4j::ops::gruCell_bp opBP; |     nd4j::ops::gruCell_bp opBP; | ||||||
| 
 | 
 | ||||||
|     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); |     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1 , 1, 1}, {0., 1.}, nd4j::GradCheck::LossFunc::SUM, true); | ||||||
| 
 |  | ||||||
|     ASSERT_TRUE(isGradCorrect); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ////////////////////////////////////////////////////////////////////
 |  | ||||||
| TEST_F(DeclarableOpsTests9, gru_cell_bp_test2) { |  | ||||||
| 
 |  | ||||||
|     const int bS = 2; |  | ||||||
|     const int iS = 3; |  | ||||||
|     const int nU = 4; |  | ||||||
| 
 |  | ||||||
|     NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE); |  | ||||||
| 
 |  | ||||||
|     x.linspace(0.5, 0.5); |  | ||||||
|     h0 = 1.; |  | ||||||
|     Wx = 0.003; |  | ||||||
|     Wh = 0.006; |  | ||||||
|     b  = 0.; |  | ||||||
| 
 |  | ||||||
|     const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {}); |  | ||||||
|     const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {}); |  | ||||||
| 
 |  | ||||||
|     nd4j::ops::gruCell opFF; |  | ||||||
|     nd4j::ops::gruCell_bp opBP; |  | ||||||
| 
 |  | ||||||
|     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); |  | ||||||
| 
 |  | ||||||
|     ASSERT_TRUE(isGradCorrect); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ////////////////////////////////////////////////////////////////////
 |  | ||||||
| TEST_F(DeclarableOpsTests9, gru_cell_bp_test3) { |  | ||||||
| 
 |  | ||||||
|     const int bS = 2; |  | ||||||
|     const int iS = 3; |  | ||||||
|     const int nU = 4; |  | ||||||
| 
 |  | ||||||
|     NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE); |  | ||||||
|     NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE); |  | ||||||
|     // NDArray<double> dLdWx0('c', {iS, 3*nU});
 |  | ||||||
|     // NDArray<double> dLdWh0('c', {nU, 3*nU});
 |  | ||||||
|     // NDArray<double> dLdb0 ('c', {3*nU});
 |  | ||||||
| 
 |  | ||||||
|     x = 1.; |  | ||||||
|     h0 = 0.0; |  | ||||||
|     Wx = 0.0; |  | ||||||
|     Wh = 0.0; |  | ||||||
|     b  = 0.5; |  | ||||||
| 
 |  | ||||||
|     const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {}); |  | ||||||
|     const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {}); |  | ||||||
| 
 |  | ||||||
|     nd4j::ops::gruCell opFF; |  | ||||||
|     nd4j::ops::gruCell_bp opBP; |  | ||||||
| 
 |  | ||||||
|     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); |  | ||||||
| 
 |  | ||||||
|     ASSERT_TRUE(isGradCorrect); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ////////////////////////////////////////////////////////////////////
 |  | ||||||
| // TEST_F(DeclarableOpsTests9, gru_bp_test1) {
 |  | ||||||
| 
 |  | ||||||
| //     const int time = 5;
 |  | ||||||
| //     const int bS   = 2;
 |  | ||||||
| //     const int iS   = 3;
 |  | ||||||
| //     const int nU   = 4;
 |  | ||||||
| 
 |  | ||||||
| //     NDArray<double> x     ('c', {time, bS, iS});
 |  | ||||||
| //     NDArray<double> h0    ('c', {bS, nU});
 |  | ||||||
| //     NDArray<double> Wx    ('c', {iS, 3*nU});
 |  | ||||||
| //     NDArray<double> Wh    ('c', {nU, 3*nU});
 |  | ||||||
| //     NDArray<double> b     ('c', {3*nU});
 |  | ||||||
| //     NDArray<double> dLdh  ('c', {time, bS, nU});
 |  | ||||||
| 
 |  | ||||||
| //     x.linspace(0.5, 0.5);
 |  | ||||||
| //     h0 = 1.;
 |  | ||||||
| //     Wx = 0.003;
 |  | ||||||
| //     Wh = 0.006;
 |  | ||||||
| //     b  = 0.5;
 |  | ||||||
| 
 |  | ||||||
| //     const OpArgsHolder<double> argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
 |  | ||||||
| //     const OpArgsHolder<double> argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
 |  | ||||||
| 
 |  | ||||||
| //     nd4j::ops::gru<double> opFF;
 |  | ||||||
| //     nd4j::ops::gru_bp<double> opBP;
 |  | ||||||
| 
 |  | ||||||
| //     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
 |  | ||||||
| 
 |  | ||||||
| //     ASSERT_TRUE(isGradCorrect);
 |  | ||||||
| // }
 |  | ||||||
| 
 |  | ||||||
| ////////////////////////////////////////////////////////////////////
 |  | ||||||
| TEST_F(DeclarableOpsTests9, gru_cell_bp_test3_1) { |  | ||||||
| 
 |  | ||||||
|     const int bS = 2; |  | ||||||
|     const int iS = 3; |  | ||||||
|     const int nU = 4; |  | ||||||
| 
 |  | ||||||
|     auto x  = NDArrayFactory::create<double>('c', {bS, iS}); |  | ||||||
|     auto h0  = NDArrayFactory::create<double>('c', {bS, nU}); |  | ||||||
|     auto Wx  = NDArrayFactory::create<double>('c', {iS, 3*nU}); |  | ||||||
|     auto Wh  = NDArrayFactory::create<double>('c', {nU, 3*nU}); |  | ||||||
|     auto b  = NDArrayFactory::create<double>('c', {3*nU}); |  | ||||||
|     auto dLdh  = NDArrayFactory::create<double>('c', {bS, nU}); |  | ||||||
|     // NDArray<double> dLdWx0('c', {iS, 3*nU});
 |  | ||||||
|     // NDArray<double> dLdWh0('c', {nU, 3*nU});
 |  | ||||||
|     // NDArray<double> dLdb0 ('c', {3*nU});
 |  | ||||||
| 
 |  | ||||||
|     x = 1.; |  | ||||||
|     h0 = 0.0; |  | ||||||
|     Wx = 0.0; |  | ||||||
|     Wh = 0.0; |  | ||||||
|     b  = 0.5; |  | ||||||
| 
 |  | ||||||
|     const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {}); |  | ||||||
|     const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {}); |  | ||||||
| 
 |  | ||||||
|     nd4j::ops::gruCell opFF; |  | ||||||
|     nd4j::ops::gruCell_bp opBP; |  | ||||||
| 
 |  | ||||||
|     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); |  | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(isGradCorrect); |     ASSERT_TRUE(isGradCorrect); | ||||||
| } | } | ||||||
| */ | */ | ||||||
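Although the gruCell_bp test above stays commented out, its gradient bookkeeping is self-consistent with the cell output it assumes, h = u * hPrev + (1 - u) * c: dL/du = dL/dh * (hPrev - c), dL/dc = dL/dh * (1 - u), and dL/dr = (dL/dc * (1 - c*c) * hPrev) · Wc_h^T, mirroring the dhdu/dhdc/dLdr lines in the test. A scalar restatement of the first two relations (illustrative helper, not library code):

```cpp
// Scalar sanity check for the algebra in the disabled gruCell_bp test.
void gruCellOutputGrads(double u, double c, double hPrev, double dLdh,
                        double& dLdu, double& dLdc) {
    // forward: h = u * hPrev + (1 - u) * c
    dLdu = dLdh * (hPrev - c);   // matches "dhdu = hi - *c" in the test above
    dLdc = dLdh * (1.0 - u);     // matches "dhdc = 1. - *u" in the test above
}
```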
|  | 
 | ||||||
| ////////////////////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests9, Cholesky_Test_1) { | TEST_F(DeclarableOpsTests9, Cholesky_Test_1) { | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -719,6 +719,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ParityOpsTests, Test_Scatter_Add_2) { | TEST_F(ParityOpsTests, Test_Scatter_Add_2) { | ||||||
|  | 
 | ||||||
|     auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4}); |     auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4}); | ||||||
|     NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64); |     NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64); | ||||||
|     auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1}); |     auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1}); | ||||||
| @ -1588,36 +1589,79 @@ TEST_F(ParityOpsTests, scatterND_update_test5) { | |||||||
|     delete result; |     delete result; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | //////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(ParityOpsTests, scatter_update_1) { | TEST_F(ParityOpsTests, scatter_update_1) { | ||||||
|     auto matrix  = NDArrayFactory::create_<float>('c', {3, 2}); |  | ||||||
|     auto updates = NDArrayFactory::create_<float>('c', {2, 2}); |  | ||||||
|     updates->assign(1.0); |  | ||||||
| 
 | 
 | ||||||
|     //updates.printBuffer("Updates");
 |     NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32); | ||||||
|  |     NDArray updates('c', {2,2}, {10,20,30,40}, nd4j::DataType::INT32); | ||||||
| 
 | 
 | ||||||
|     auto variableSpace = new VariableSpace(); |     NDArray exp('c', {2,2}, {30,40,10,20}, nd4j::DataType::INT32); | ||||||
|     variableSpace->putVariable(-1, matrix); |  | ||||||
|     variableSpace->putVariable(-2, updates); |  | ||||||
|     variableSpace->putVariable(1, new Variable(&matrix)); |  | ||||||
| 
 |  | ||||||
|     auto block = new Context(1, variableSpace, false); |  | ||||||
|     block->fillInputs({-1, -2}); |  | ||||||
| 
 |  | ||||||
|     std::vector<int>* arguments = block->getIArguments(); |  | ||||||
|     arguments->push_back(0); |  | ||||||
|     arguments->push_back(1); |  | ||||||
|     arguments->push_back(1); |  | ||||||
|     arguments->push_back(2); |  | ||||||
|     arguments->push_back(1); |  | ||||||
|     arguments->push_back(2); |  | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::scatter_update op; |     nd4j::ops::scatter_update op; | ||||||
|  |     auto results = op.execute({&x, &updates}, {}, {6,   1,1,  2,1,0}); | ||||||
| 
 | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  |     // x.printBuffer();
 | ||||||
| 
 | 
 | ||||||
|     Nd4jStatus result = op.execute(block); |     ASSERT_TRUE(exp.isSameShape(x)); | ||||||
|     ASSERT_EQ(ND4J_STATUS_OK, result); |     ASSERT_TRUE(exp.equalsTo(x)); | ||||||
| 
 | 
 | ||||||
|     delete block; |     delete results; | ||||||
|     delete variableSpace; |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | //////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(ParityOpsTests, scatter_update_2) { | ||||||
|  | 
 | ||||||
|  |     NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32); | ||||||
|  |     NDArray updates('c', {2,2}, {10,20,30,40}, nd4j::DataType::INT32); | ||||||
|  | 
 | ||||||
|  |     NDArray exp('c', {2,2}, {20,10,40,30}, nd4j::DataType::INT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::scatter_update op; | ||||||
|  |     auto results = op.execute({&x, &updates}, {}, {6,   1,0,  2,1,0}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(x)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(x)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(ParityOpsTests, scatter_update_3) { | ||||||
|  | 
 | ||||||
|  |     NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, nd4j::DataType::INT32); | ||||||
|  |     NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, nd4j::DataType::INT32); | ||||||
|  | 
 | ||||||
|  |     NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, nd4j::DataType::INT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::scatter_update op; | ||||||
|  |     auto results = op.execute({&x, &updates}, {}, {6,  2,1,2,  2,1,0}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(x)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(x)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(ParityOpsTests, scatter_update_4) { | ||||||
|  | 
 | ||||||
|  |     NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, nd4j::DataType::INT32); | ||||||
|  |     NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, nd4j::DataType::INT32); | ||||||
|  | 
 | ||||||
|  |     NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, nd4j::DataType::INT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::scatter_update op; | ||||||
|  |     auto results = op.execute({&x, &updates}, {}, {6,  1,0,  2,3,0}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(exp.isSameShape(x)); | ||||||
|  |     ASSERT_TRUE(exp.equalsTo(x)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | |||||||
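The integer arguments to scatter_update in these tests are easiest to read with a legend. Reconstructed from how the four cases map inputs to expectations (an inference from the tests, not a documented contract), the layout appears to be {opCode, nDims, dim_0..dim_{nDims-1}, nIndices, idx_0..idx_{nIndices-1}}, where opCode 6 copies update sub-array e (a TAD over the listed dims) into the operand sub-array at position idx_e, in place.

```cpp
// Row-wise reduction of that reading (dims = {1} on a rank-2 operand): update row e
// overwrites operand row idx_e. Hypothetical helper name, illustrative only.
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<int>>;

void scatterUpdateCopyRowsSketch(Mat& x, const Mat& updates, const std::vector<int>& indices) {
    for (std::size_t e = 0; e < indices.size(); ++e)
        x[indices[e]] = updates[e];     // e-th update row replaces row indices[e] of x
}
```

With x = {{1,2},{3,4}}, updates = {{10,20},{30,40}} and indices {1,0}, this reproduces the scatter_update_1 expectation {{30,40},{10,20}}.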
| @ -278,8 +278,8 @@ TEST_F(RNGTests, Test_Gaussian_22) { | |||||||
|     auto x0 = NDArrayFactory::create<float>('c', {10000, 1000}); |     auto x0 = NDArrayFactory::create<float>('c', {10000, 1000}); | ||||||
|     auto x1 = NDArrayFactory::create<float>('c', {10000, 1000}); |     auto x1 = NDArrayFactory::create<float>('c', {10000, 1000}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillGaussian(_rngA, &x0, 0.0f, 1.0f); |     RandomLauncher::fillGaussian(nd4j::LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); | ||||||
|     RandomLauncher::fillGaussian(_rngB, &x1, 0.0f, 1.0f); |     RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); | ||||||
| 
 | 
 | ||||||
|     //x0.printIndexedBuffer("x0");
 |     //x0.printIndexedBuffer("x0");
 | ||||||
|     //x1.printIndexedBuffer("x1");
 |     //x1.printIndexedBuffer("x1");
 | ||||||
| @ -306,7 +306,7 @@ TEST_F(RNGTests, Test_Gaussian_22) { | |||||||
| TEST_F(RNGTests, Test_Gaussian_3) { | TEST_F(RNGTests, Test_Gaussian_3) { | ||||||
|     auto x0 = NDArrayFactory::create<double>('c', {10000000}); |     auto x0 = NDArrayFactory::create<double>('c', {10000000}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillGaussian(_rngA, &x0, 0.0, 1.0); |     RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, 1.0); | ||||||
| 
 | 
 | ||||||
|     auto mean = x0.meanNumber().e<double>(0); |     auto mean = x0.meanNumber().e<double>(0); | ||||||
|     auto stdev = x0.varianceNumber(nd4j::variance::SummaryStatsStandardDeviation, false).e<double>(0); |     auto stdev = x0.varianceNumber(nd4j::variance::SummaryStatsStandardDeviation, false).e<double>(0); | ||||||
| @ -319,8 +319,8 @@ TEST_F(RNGTests, Test_LogNormal_1) { | |||||||
|     auto x0 = NDArrayFactory::create<float>('c', {10, 10}); |     auto x0 = NDArrayFactory::create<float>('c', {10, 10}); | ||||||
|     auto x1 = NDArrayFactory::create<float>('c', {10, 10}); |     auto x1 = NDArrayFactory::create<float>('c', {10, 10}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillLogNormal(_rngA, &x0, 1.0f, 2.0f); |     RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); | ||||||
|     RandomLauncher::fillLogNormal(_rngB, &x1, 1.0f, 2.0f); |     RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(x0.equalsTo(&x1)); |     ASSERT_TRUE(x0.equalsTo(&x1)); | ||||||
| 
 | 
 | ||||||
| @ -333,8 +333,8 @@ TEST_F(RNGTests, Test_Truncated_1) { | |||||||
|     auto x0 = NDArrayFactory::create<float>('c', {10, 10}); |     auto x0 = NDArrayFactory::create<float>('c', {10, 10}); | ||||||
|     auto x1 = NDArrayFactory::create<float>('c', {10, 10}); |     auto x1 = NDArrayFactory::create<float>('c', {10, 10}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(x0.equalsTo(&x1)); |     ASSERT_TRUE(x0.equalsTo(&x1)); | ||||||
| 
 | 
 | ||||||
| @ -357,8 +357,8 @@ TEST_F(RNGTests, Test_Truncated_2) { | |||||||
|     auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); |     auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); | ||||||
|     auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); |     auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(x0.equalsTo(&x1)); |     ASSERT_TRUE(x0.equalsTo(&x1)); | ||||||
| 
 | 
 | ||||||
| @ -383,8 +383,8 @@ TEST_F(RNGTests, Test_Truncated_21) { | |||||||
|     auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); |     auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); | ||||||
|     auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); |     auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(x0.equalsTo(&x1)); |     ASSERT_TRUE(x0.equalsTo(&x1)); | ||||||
| 
 | 
 | ||||||
| @ -430,8 +430,8 @@ TEST_F(RNGTests, Test_Truncated_22) { | |||||||
|     auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); |     auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); | ||||||
|     auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); |     auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngA, &x0, 2.0f, 4.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 2.0f, 4.0f); | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngB, &x1, 2.0f, 4.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 2.0f, 4.0f); | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(x0.equalsTo(&x1)); |     ASSERT_TRUE(x0.equalsTo(&x1)); | ||||||
| 
 | 
 | ||||||
| @ -477,8 +477,8 @@ TEST_F(RNGTests, Test_Truncated_23) { | |||||||
|     auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); |     auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); | ||||||
|     auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); |     auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngA, &x0, 0.0f, 1.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngB, &x1, 0.0f, 1.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(x0.equalsTo(&x1)); |     ASSERT_TRUE(x0.equalsTo(&x1)); | ||||||
| 
 | 
 | ||||||
| @ -524,8 +524,8 @@ TEST_F(RNGTests, Test_Truncated_3) { | |||||||
|     auto x0 = NDArrayFactory::create<float>('c', {10000, 1000}); |     auto x0 = NDArrayFactory::create<float>('c', {10000, 1000}); | ||||||
|     auto x1 = NDArrayFactory::create<float>('c', {10000, 1000}); |     auto x1 = NDArrayFactory::create<float>('c', {10000, 1000}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); | ||||||
|     RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f); |     RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); | ||||||
| 
 | 
 | ||||||
|     ASSERT_TRUE(x0.equalsTo(&x1)); |     ASSERT_TRUE(x0.equalsTo(&x1)); | ||||||
| 
 | 
 | ||||||
| @ -964,7 +964,7 @@ TEST_F(RNGTests, Test_Reproducibility_2) { | |||||||
| TEST_F(RNGTests, Test_Uniform_4) { | TEST_F(RNGTests, Test_Uniform_4) { | ||||||
|     auto x1 = NDArrayFactory::create<double>('c', {1000000}); |     auto x1 = NDArrayFactory::create<double>('c', {1000000}); | ||||||
| 
 | 
 | ||||||
|     RandomLauncher::fillUniform(_rngB, &x1, 1.0, 2.0); |     RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0, 2.0); | ||||||
| 
 | 
 | ||||||
|     /* Check up distribution */ |     /* Check up distribution */ | ||||||
|     auto mean = x1.reduceNumber(reduce::Mean); |     auto mean = x1.reduceNumber(reduce::Mean); | ||||||
|  | |||||||
| @ -69,6 +69,24 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_1) { | |||||||
|     ASSERT_EQ(ev, v); |     ASSERT_EQ(ev, v); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(SortCudaTests, test_linear_sort_by_val_2) { | ||||||
|  |     auto k = NDArrayFactory::create<int>('c', {6}, {0, 1, 2, 3, 4, 5}); | ||||||
|  |     auto v = NDArrayFactory::create<double>('c', {6}, {0.9, 0.75, 0.6, 0.95, 0.5, 0.3}); | ||||||
|  |     auto ek = NDArrayFactory::create<int>('c', {6}, {3, 0, 1, 2, 4, 5}); | ||||||
|  |     auto ev = NDArrayFactory::create<double>('c', {6}, {0.95, 0.9, 0.75, 0.6, 0.5, 0.3}); | ||||||
|  | 
 | ||||||
|  |     Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; | ||||||
|  | 
 | ||||||
|  |     NativeOps nativeOps; | ||||||
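|  |     // extras[1] hands the default CUDA stream to the native sort helper | ||||||
|  |     // the trailing 'true' flag is assumed to request a descending sort (expected values run 0.95 down to 0.3) | ||||||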
|  |     nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true); | ||||||
|  |     k.tickWriteDevice(); | ||||||
|  |     v.tickWriteDevice(); | ||||||
|  |     k.printIndexedBuffer("KEYS"); | ||||||
|  |     ASSERT_EQ(ek, k); | ||||||
|  |     ASSERT_EQ(ev, v); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| TEST_F(SortCudaTests, test_tad_sort_by_key_1) { | TEST_F(SortCudaTests, test_tad_sort_by_key_1) { | ||||||
|     auto k = NDArrayFactory::create<Nd4jLong>('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8,   1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); |     auto k = NDArrayFactory::create<Nd4jLong>('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8,   1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); | ||||||
|     auto v = NDArrayFactory::create<double>('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5,   1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); |     auto v = NDArrayFactory::create<double>('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5,   1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); | ||||||
|  | |||||||