[WIP] More of CUDA operations (#69)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* - gruCell_bp further

Signed-off-by: Yurii <yurii@skymind.io>

* - further work on gruCell_bp

Signed-off-by: Yurii <yurii@skymind.io>

* Inverse matrix cublas implementation. Partial working revision.

* Separation of segment ops helpers. Max separation.

* Separated segment_min ops.

* Separation of segment_mean/sum/prod/sqrtN ops heleprs.

* Fixed diagonal processing with LUP decomposition.

* Modified inversion approach using current state of LU decomposition.

* Implementation of matrix_inverse op with cuda kernels. Working revision.

* Implemented sequence_mask cuda helper. Eliminated waste printf with matrix_inverse implementation. Added proper tests.

* - further work on gruCell_bp (ff/cuda)

Signed-off-by: Yurii <yurii@skymind.io>

* comment one test for gruCell_bp

Signed-off-by: Yurii <yurii@skymind.io>

* - provide cuda static_rnn

Signed-off-by: Yurii <yurii@skymind.io>

* Refactored random_shuffle op to use new random generator.

* Refactored random_shuffle op helper.

* Fixed debug tests with random ops tests.

* Implement random_shuffle op cuda kernel helper and tests.

* - provide cuda scatter_update

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of random_shuffle for linear case with cuda kernels and tests.

* Implemented random_shuffle with cuda kernels. Final revision.

* - finally gruCell_bp is completed

Signed-off-by: Yurii <yurii@skymind.io>

* Dropout op cuda helper implementation.

* Implemented dropout_bp cuda helper.

* Implemented alpha_dropout_bp with cuda kernel helpers.

* Refactored helper.

* Implementation of suppresion helper with cuda kernels.

* - provide cpu code fot hsvToRgb, rgbToHsv, adjustHue

Signed-off-by: Yurii <yurii@skymind.io>

* Using sort by value method.

* Implementation of image.non_max_suppression op cuda-based helper.

* - correcting and testing adjust_hue, adjust_saturation cpu/cuda code

Signed-off-by: Yurii <yurii@skymind.io>

* Added cuda device prefixes to declarations.

* Implementation of hashcode op with cuda helper. Initital revision.

* rnn cu impl removed

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-20 08:58:44 +03:00 committed by AlexDBlack
parent 06e4f5f96e
commit 763a225c6a
62 changed files with 5615 additions and 3848 deletions

View File

@ -208,9 +208,9 @@ namespace nd4j {
NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext()); NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
/** /**
* this constructor creates new array using given buffer (without memory allocating) and shape information stored in shape * this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape
*/ */
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext()); NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
/** /**
* this constructor creates new NDArray with shape matching "other" array, * this constructor creates new NDArray with shape matching "other" array,

View File

@ -132,9 +132,8 @@ NDArray::NDArray(const NDArray *other, const bool copyStrides, nd4j::LaunchConte
_buffer = std::make_shared<DataBuffer>(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); _buffer = std::make_shared<DataBuffer>(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) { NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context, const bool isBuffAlloc) {
if (shape.empty()) if (shape.empty())
throw std::runtime_error("NDArray constructor: input shape is empty !"); throw std::runtime_error("NDArray constructor: input shape is empty !");
@ -148,7 +147,7 @@ NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &sh
setShapeInfo(ShapeDescriptor(dtype, order, shape)); setShapeInfo(ShapeDescriptor(dtype, order, shape));
_buffer = std::make_shared<DataBuffer>(buffer, lengthOf() * sizeOfT(), dataType(), true, getContext()->getWorkspace()); _buffer = std::make_shared<DataBuffer>(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace());
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////

View File

@ -1498,16 +1498,6 @@ void NativeOps::specialConcat(
* This method saves * This method saves
*/ */
nd4j::TadPack* NativeOps::tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensionLength) { nd4j::TadPack* NativeOps::tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensionLength) {
/*shape::TAD tad;
tad.init(dXShapeInfo, dimension, dimensionLength);
//tad->setOutputBuffer(target);
tad.createTadOnlyShapeInfo();
tad.createOffsets();
std::memcpy(reinterpret_cast<void *>(target), tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo));
std::memcpy(reinterpret_cast<void *>(offsets), tad.tadOffsets, tad.numTads * sizeof(Nd4jLong));
*/
auto pack = new TadPack(); auto pack = new TadPack();
*pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength); *pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength);
return pack; return pack;

View File

@ -45,9 +45,9 @@ namespace nd4j {
static ConstantTadHelper* getInstance(); static ConstantTadHelper* getInstance();
TadPack& tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false); TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false); TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false); TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false); TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
TadPack& tadForDimensions(TadDescriptor &descriptor); TadPack& tadForDimensions(TadDescriptor &descriptor);
}; };

View File

@ -38,15 +38,15 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape); return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }

View File

@ -42,15 +42,15 @@ namespace nd4j {
return _INSTANCE; return _INSTANCE;
} }
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) { TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape); return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) { TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape); return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
} }
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) { TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape); TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
return tadForDimensions(tadDescriptor); return tadForDimensions(tadDescriptor);
} }

View File

@ -58,7 +58,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs(); const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs(); const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();
// fill input gradient arrays in accordance to type of loss function // fill input gradient arrays in accordance to kind of loss function
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP])); fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
// beck prop pass // beck prop pass

View File

@ -987,9 +987,10 @@ namespace shape {
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
ND4J_EXPORT _CUDA_HD int outerArrayIndexes(Nd4jLong* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr); ND4J_EXPORT _CUDA_HD int outerArrayIndexes(Nd4jLong* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr);
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array // calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude = nullptr); ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude = nullptr);
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array

View File

@ -16,6 +16,7 @@
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <op_boilerplate.h> #include <op_boilerplate.h>
@ -28,46 +29,35 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
DECLARE_TYPES(adjust_hue) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setSameMode(true);
}
CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, -2, -2) { CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 3 || input->rankOf() == 4, 0, "AdjustHue: op expects either 3D or 4D input, but got %i instead", input->rankOf()); const int rank = input->rankOf();
const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
const double delta = T_ARG(0);
REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank);
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);
double delta = 0; NDArray deltaScalarArr = NDArrayFactory::create<double>(delta, block.launchContext());
if (block.numT() > 0)
delta = T_ARG(0);
else if (block.width() > 1) {
auto _d = INPUT_VARIABLE(1);
if (!_d->isScalar()) {
auto str = ShapeUtils::shapeAsString(_d);
REQUIRE_TRUE(_d->isScalar(), 0, "AdjustHue: delta should be scalar NDArray, but got %s instead", str.c_str());
}
delta = _d->e<double>(0);
}
helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC);
bool isNHWC = false;
if (block.numI() > 0)
isNHWC = INT_ARG(0) == 1;
int numChannels = isNHWC ? input->sizeAt(-1) : input->sizeAt(-3);
REQUIRE_TRUE(numChannels == 3, 0, "AdjustHue: this operation expects image with 3 channels (R, G, B), but got % instead", numChannels);
auto ts = NDArrayFactory::create(delta, block.launchContext());
// FIXME: delta should be NDArray scalar
helpers::_adjust_hue(block.launchContext(), input, output, &ts, isNHWC);
return Status::OK(); return Status::OK();
} }
DECLARE_TYPES(adjust_hue) {
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
->setSameMode(true);
}
} }
} }

View File

@ -27,45 +27,33 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
DECLARE_TYPES(adjust_saturation) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setSameMode(true);
}
CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, -2, -2) { CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 3 || input->rankOf() == 4, 0, "AdjustSaturation: op expects either 3D or 4D input, but got %i instead", input->rankOf()); const int rank = input->rankOf();
const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
const double factor = T_ARG(0);
double delta = 0; REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank);
if (block.numT() > 0) REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
delta = T_ARG(0);
else if (block.width() > 1) {
auto _d = INPUT_VARIABLE(1);
if (!_d->isScalar()) {
auto str = ShapeUtils::shapeAsString(_d);
REQUIRE_TRUE(_d->isScalar(), 0, "AdjustSaturation: delta should be scalar NDArray, but got %s instead", str.c_str());
}
delta = _d->e<double>(0); NDArray factorScalarArr = NDArrayFactory::create<double>(factor, block.launchContext());
}
bool isNHWC = false; helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC);
if (block.numI() > 0)
isNHWC = INT_ARG(0) == 1;
int numChannels = isNHWC ? input->sizeAt(-1) : input->sizeAt(-3);
REQUIRE_TRUE(numChannels == 3, 0, "AdjustSaturation: this operation expects image with 3 channels (R, G, B), but got % instead", numChannels);
auto ts = NDArrayFactory::create(delta, block.launchContext());
// FIXME: delta should be NDArray scalar
helpers::adjust_saturation(block.launchContext(), input, output, &ts, isNHWC);
return Status::OK(); return Status::OK();
} }
DECLARE_TYPES(adjust_saturation) {
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
->setSameMode(true);
}
} }
} }

View File

@ -27,6 +27,7 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
OP_IMPL(scatter_add, 3, 1, true) { OP_IMPL(scatter_add, 3, 1, true) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto indices = INPUT_VARIABLE(1); auto indices = INPUT_VARIABLE(1);
@ -74,8 +75,8 @@ namespace nd4j {
return Status::OK(); return Status::OK();
} }
DECLARE_SYN(ScatterAdd, scatter_add); DECLARE_SYN(ScatterAdd, scatter_add);
}
DECLARE_TYPES(scatter_add) { DECLARE_TYPES(scatter_add) {
getOpDescriptor() getOpDescriptor()
@ -84,6 +85,8 @@ namespace nd4j {
->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS}) ->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}); ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
} }
}
} }
#endif #endif

View File

@ -57,16 +57,26 @@ namespace nd4j {
auto in = inputShape->at(0); auto in = inputShape->at(0);
int outRank = shape::rank(in) + 1; int outRank = shape::rank(in) + 1;
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto dtype = DataType::BOOL;
Nd4jLong maxInd = input->argMax(); Nd4jLong maxInd = input->argMax();
float max = input->e<float>(maxInd); Nd4jLong max = input->e<Nd4jLong>(maxInd);
if (block.getIArguments()->size() > 0) { if (block.getIArguments()->size() > 0) {
if (block.width() < 2) {
maxInd = INT_ARG(0); maxInd = INT_ARG(0);
if (maxInd < max) if (maxInd < max)
maxInd = static_cast<Nd4jLong>(max); maxInd = static_cast<Nd4jLong>(max);
if (block.getIArguments()->size() > 1)
dtype = (DataType)INT_ARG(1);
} }
else if (block.width() > 1) { else {
dtype = (DataType)INT_ARG(0);
}
}
if (block.width() > 1) {
auto maxlen = INPUT_VARIABLE(1); auto maxlen = INPUT_VARIABLE(1);
float tmaxlen = maxlen->e<float>(0); Nd4jLong tmaxlen = maxlen->e<Nd4jLong>(0);
if (tmaxlen > max) if (tmaxlen > max)
maxInd = static_cast<Nd4jLong>(tmaxlen); maxInd = static_cast<Nd4jLong>(tmaxlen);
} }
@ -80,14 +90,14 @@ namespace nd4j {
outShapeInfo[i + 1] = shape::sizeAt(in, i); outShapeInfo[i + 1] = shape::sizeAt(in, i);
outShapeInfo[outRank] = lastDimension; outShapeInfo[outRank] = lastDimension;
ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in)); ShapeUtils::updateStridesAndType(outShapeInfo, dtype, shape::order(in));
return SHAPELIST(CONSTANT(outShapeInfo)); return SHAPELIST(CONSTANT(outShapeInfo));
} }
DECLARE_TYPES(sequence_mask) { DECLARE_TYPES(sequence_mask) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedInputTypes({ALL_INTS})
->setAllowedOutputTypes(nd4j::DataType::ANY); ->setAllowedOutputTypes(nd4j::DataType::ANY);
} }
} }

View File

@ -33,11 +33,11 @@ OP_IMPL(random_shuffle, 1, 1, true) {
const bool isInplace = block.isInplace(); const bool isInplace = block.isInplace();
auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0); auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0);
nd4j::random::RandomBuffer* rng = block.getRNG(); // nd4j::random::RandomBuffer* rng = block.getRNG();
nd4j::graph::RandomGenerator rng = block.randomGenerator();
// REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !");
REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !"); helpers::randomShuffle(block.launchContext(), *input, *output, rng, isInplace);
helpers::randomShuffle(block.launchContext(), *input, *output, *rng, isInplace);
return Status::OK(); return Status::OK();
} }

View File

@ -31,6 +31,7 @@ namespace ops {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) { CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) {
auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size
auto hLast = INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous time step t-1, nU - number of units auto hLast = INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights) auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights)
@ -118,65 +119,58 @@ DECLARE_SHAPE_FN(gruCell) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(gruCell_bp, 6, 5, false, 0, 0) { CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) {
auto x = INPUT_VARIABLE(0); // input [bS x iS] auto x = INPUT_VARIABLE(0); // input [bS x iS]
auto hi = INPUT_VARIABLE(1); // previous cell output [bS x nU] auto hi = INPUT_VARIABLE(1); // previous cell output [bS x nU]
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [iS x 3*nU] auto W = INPUT_VARIABLE(2); // weights, [iS+nU x 2*nU]
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nU x 3*nU] auto Wc = INPUT_VARIABLE(3); // c weights, [iS+nU x nU]
auto b = INPUT_VARIABLE(4); // biases, [3*nU] auto b = INPUT_VARIABLE(4); // biases, [2*nU]
auto dLdh = INPUT_VARIABLE(5); // gradient wrt output, [bS,nU], that is epsilon_next auto bc = INPUT_VARIABLE(5); // biases, [nU]
auto dLdWxi = block.width() > 6 ? INPUT_VARIABLE(6) : nullptr; // gradient wrt Wx at previous time step, [iS, 3*nU] auto dLdr = INPUT_VARIABLE(6); // gradient wrt reset gate, [bS, nU]
auto dLdWhi = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // gradient wrt Wh at previous time step, [nU, 3*nU] auto dLdu = INPUT_VARIABLE(7); // gradient wrt update gate, [bS, nU]
auto dLdbi = block.width() > 8 ? INPUT_VARIABLE(8) : nullptr; // gradient wrt b at previous time step, [3*nU] auto dLdc = INPUT_VARIABLE(8); // gradient wrt cell state, [bS, nU]
auto dLdh = INPUT_VARIABLE(9); // gradient wrt current cell output, [bS, nU]
auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS], that is epsilon auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS]
auto dLdhi = OUTPUT_VARIABLE(1); // gradient wrt hi, [bS, nU] auto dLdhi = OUTPUT_VARIABLE(1); // gradient wrt hi, [bS, nU]
auto dLdWx = OUTPUT_VARIABLE(2); // gradient wrt Wx, [iS, 3*nU] auto dLdW = OUTPUT_VARIABLE(2); // gradient wrt W, [iS+nU x 2*nU]
auto dLdWh = OUTPUT_VARIABLE(3); // gradient wrt Wh, [nU, 3*nU] auto dLdWc = OUTPUT_VARIABLE(3); // gradient wrt Wc, [iS+nU x nU]
auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [3*nU] auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [2*nU]
auto dLdbc = OUTPUT_VARIABLE(5); // gradient wrt c biases, [nU]
const int rank = x->rankOf(); // = 2
const Nd4jLong bS = x->sizeAt(0); const Nd4jLong bS = x->sizeAt(0);
const Nd4jLong iS = x->sizeAt(1); const Nd4jLong iS = x->sizeAt(1);
const Nd4jLong nU = hi->sizeAt(1); const Nd4jLong nU = hi->sizeAt(1);
REQUIRE_TRUE(x->rankOf() == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", x->rankOf());
const std::string hiShape = ShapeUtils::shapeAsString(hi); const std::string hiShape = ShapeUtils::shapeAsString(hi);
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU}); const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
const std::string wxShape = ShapeUtils::shapeAsString(Wx); const std::string wShape = ShapeUtils::shapeAsString(W);
const std::string wxCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU}); const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
const std::string whShape = ShapeUtils::shapeAsString(Wh); const std::string wcShape = ShapeUtils::shapeAsString(Wc);
const std::string whCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU}); const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
const std::string bShape = ShapeUtils::shapeAsString(b); const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*nU}); const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
const std::string bcShape = ShapeUtils::shapeAsString(bc);
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdr);
const std::string dLduShape = ShapeUtils::shapeAsString(dLdu);
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdc);
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdh); const std::string dLdhShape = ShapeUtils::shapeAsString(dLdh);
const std::string dLdhCorrectShape = ShapeUtils::shapeAsString({bS, nU});
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str()); REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU_CELL_BP op: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str()); REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU_CELL_BP op: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str()); REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str()); REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(dLdhShape == dLdhCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (epsilon_next), expected is %s, but got %s instead !", dLdhCorrectShape.c_str(), dLdhShape.c_str()); REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
if(dLdWxi != nullptr) { helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);
const std::string dLdWxiShape = ShapeUtils::shapeAsString(dLdWxi);
const std::string dLdWxiCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
REQUIRE_TRUE(dLdWxiShape == dLdWxiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWxi array (gradient wrt Wx at previous time step), expected is %s, but got %s instead !", dLdWxiCorrectShape.c_str(), dLdWxiShape.c_str());
}
if(dLdWhi != nullptr) {
const std::string dLdWhiShape = ShapeUtils::shapeAsString(dLdWhi);
const std::string dLdWhiCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
REQUIRE_TRUE(dLdWhiShape == dLdWhiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWhi array (gradient wrt Wh at previous time step), expected is %s, but got %s instead !", dLdWhiCorrectShape.c_str(), dLdWhiShape.c_str());
}
if(dLdbi != nullptr) {
const std::string dLdbiShape = ShapeUtils::shapeAsString(dLdbi);
const std::string dLdbiCorrectShape = ShapeUtils::shapeAsString({3*nU});
REQUIRE_TRUE(dLdbiShape == dLdbiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdbi array (gradient wrt biases at previous time step), expected is %s, but got %s instead !", dLdbiCorrectShape.c_str(), dLdbiShape.c_str());
}
helpers::gruCellBP(block.launchContext(), x, hi, Wx, Wh, b, dLdh, dLdWxi, dLdWhi, dLdbi, dLdx, dLdhi, dLdWx, dLdWh, dLdb);
return Status::OK(); return Status::OK();
} }
@ -192,6 +186,7 @@ DECLARE_TYPES(gruCell_bp) {
->setAllowedInputTypes(6, {ALL_FLOATS}) ->setAllowedInputTypes(6, {ALL_FLOATS})
->setAllowedInputTypes(7, {ALL_FLOATS}) ->setAllowedInputTypes(7, {ALL_FLOATS})
->setAllowedInputTypes(8, {ALL_FLOATS}) ->setAllowedInputTypes(8, {ALL_FLOATS})
->setAllowedInputTypes(9, {ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS}); ->setAllowedOutputTypes({ALL_FLOATS});
} }
@ -199,53 +194,46 @@ DECLARE_SHAPE_FN(gruCell_bp) {
auto xShapeInfo = inputShape->at(0); // [bS x iS] auto xShapeInfo = inputShape->at(0); // [bS x iS]
auto hiShapeInfo = inputShape->at(1); // [bS x nU] auto hiShapeInfo = inputShape->at(1); // [bS x nU]
auto wxShapeInfo = inputShape->at(2); // [iS x 3*nU] auto wShapeInfo = inputShape->at(2); // [iS+nU x 2*nU]
auto whShapeInfo = inputShape->at(3); // [nU x 3*nU] auto wcShapeInfo = inputShape->at(3); // [iS+nU x nU]
auto bShapeInfo = inputShape->at(4); // [3*nU] auto bShapeInfo = inputShape->at(4); // [2*nU]
auto dLdhShapeInfo = inputShape->at(5); // [bS x nU] auto bcShapeInfo = inputShape->at(5); // [nU]
auto dLdrShapeInfo = inputShape->at(6); // [bS, nU]
auto dLduShapeInfo = inputShape->at(7); // [bS, nU]
auto dLdcShapeInfo = inputShape->at(8); // [bS, nU]
auto dLdhShapeInfo = inputShape->at(9); // [bS, nU]
const int rank = xShapeInfo[0]; // = 2 const int rank = xShapeInfo[0]; // = 2
const Nd4jLong bS = xShapeInfo[1]; const Nd4jLong bS = xShapeInfo[1];
const Nd4jLong iS = xShapeInfo[2]; const Nd4jLong iS = xShapeInfo[2];
const Nd4jLong nU = hiShapeInfo[2]; const Nd4jLong nU = hiShapeInfo[2];
REQUIRE_TRUE(xShapeInfo[0] == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", xShapeInfo[0]);
const std::string hiShape = ShapeUtils::shapeAsString(hiShapeInfo); const std::string hiShape = ShapeUtils::shapeAsString(hiShapeInfo);
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU}); const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
const std::string wxShape = ShapeUtils::shapeAsString(wxShapeInfo); const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wxCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU}); const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
const std::string whShape = ShapeUtils::shapeAsString(whShapeInfo); const std::string wcShape = ShapeUtils::shapeAsString(wcShapeInfo);
const std::string whCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU}); const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo); const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*nU}); const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
const std::string bcShape = ShapeUtils::shapeAsString(bcShapeInfo);
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdrShapeInfo);
const std::string dLduShape = ShapeUtils::shapeAsString(dLduShapeInfo);
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdcShapeInfo);
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdhShapeInfo); const std::string dLdhShape = ShapeUtils::shapeAsString(dLdhShapeInfo);
const std::string dLdhCorrectShape = ShapeUtils::shapeAsString({bS, nU});
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str()); REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU_CELL_BP op: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str()); REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU_CELL_BP op: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str()); REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str()); REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(dLdhShape == dLdhCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (epsilon_next), expected is %s, but got %s instead !", dLdhCorrectShape.c_str(), dLdhShape.c_str()); REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
if(block.width() > 6) { REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
Nd4jLong* dLdWxiShapeInfo = inputShape->at(6); // [iS x 3*nU] REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
const std::string dLdWxiShape = ShapeUtils::shapeAsString(dLdWxiShapeInfo); REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
const std::string dLdWxiCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
REQUIRE_TRUE(dLdWxiShape == dLdWxiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWxi array (gradient wrt Wx at previous time step), expected is %s, but got %s instead !", dLdWxiCorrectShape.c_str(), dLdWxiShape.c_str());
}
if(block.width() > 7) {
Nd4jLong* dLdWhiShapeInfo = inputShape->at(7); // [nU x 3*nU]
const std::string dLdWhiShape = ShapeUtils::shapeAsString(dLdWhiShapeInfo);
const std::string dLdWhiCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
REQUIRE_TRUE(dLdWhiShape == dLdWhiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWhi array (gradient wrt Wh at previous time step), expected is %s, but got %s instead !", dLdWhiCorrectShape.c_str(), dLdWhiShape.c_str());
}
if(block.width() > 8) {
Nd4jLong* dLdbiShapeInfo = inputShape->at(8); // [3*nU]
const std::string dLdbiShape = ShapeUtils::shapeAsString(dLdbiShapeInfo);
const std::string dLdbiCorrectShape = ShapeUtils::shapeAsString({3*nU});
REQUIRE_TRUE(dLdbiShape == dLdbiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdbi array (gradient wrt biases at previous time step), expected is %s, but got %s instead !", dLdbiCorrectShape.c_str(), dLdbiShape.c_str());
}
Nd4jLong *dLdxShapeInfo = nullptr; Nd4jLong *dLdxShapeInfo = nullptr;
COPY_SHAPE(xShapeInfo, dLdxShapeInfo); COPY_SHAPE(xShapeInfo, dLdxShapeInfo);
@ -253,17 +241,19 @@ DECLARE_SHAPE_FN(gruCell_bp) {
Nd4jLong *dLdhiShapeInfo = nullptr; Nd4jLong *dLdhiShapeInfo = nullptr;
COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo); COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo);
Nd4jLong *dLdWxShapeInfo = nullptr; Nd4jLong *dLdWShapeInfo = nullptr;
COPY_SHAPE(wxShapeInfo, dLdWxShapeInfo); COPY_SHAPE(wShapeInfo, dLdWShapeInfo);
Nd4jLong *dLdWhShapeInfo = nullptr; Nd4jLong *dLdWcShapeInfo = nullptr;
COPY_SHAPE(whShapeInfo, dLdWhShapeInfo); COPY_SHAPE(wcShapeInfo, dLdWcShapeInfo);
Nd4jLong *dLdbShapeInfo = nullptr; Nd4jLong *dLdbShapeInfo = nullptr;
COPY_SHAPE(bShapeInfo, dLdbShapeInfo); COPY_SHAPE(bShapeInfo, dLdbShapeInfo);
return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo); Nd4jLong *dLdbcShapeInfo = nullptr;
COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo);
return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo);
} }

View File

@ -553,33 +553,31 @@ namespace nd4j {
/** /**
* This operation adjusts image hue by delta * This operation adjusts image hue by delta
* Input arrays: * Input arrays:
* 0 - 1D or 3D input array, must have 3 channels. * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
* 1 - optional scalar, delta value
* *
* T arguments: * T arguments:
* 0 - optional delta value * 0 - delta value
* *
* Int arguments: * Int arguments:
* 0 - optional argument, isNHWC. false by default. * 0 - optional argument, corresponds to dimension with 3 channels
*/ */
#if NOT_EXCLUDED(OP_adjust_hue) #if NOT_EXCLUDED(OP_adjust_hue)
DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, -2, -2); DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2);
#endif #endif
/** /**
* This operation adjusts image saturation by delta * This operation adjusts image saturation by delta
* Input arrays: * Input arrays:
* 0 - 1D or 3D input array, must have 3 channels. * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
* 1 - optional scalar, delta value
* *
* T arguments: * T arguments:
* 0 - optional delta value * 0 - saturation factor
* *
* Int arguments: * Int arguments:
* 0 - optional argument, isNHWC. false by default. * 0 - optional argument, corresponds to dimension with 3 channels
*/ */
#if NOT_EXCLUDED(OP_adjust_saturation) #if NOT_EXCLUDED(OP_adjust_saturation)
DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, -2, -2); DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
#endif #endif

View File

@ -259,8 +259,8 @@ namespace ops {
* Input arrays: * Input arrays:
* 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features * 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features
* 1: previous cell output [batchSize x numUnits], that is at previous time step t-1 * 1: previous cell output [batchSize x numUnits], that is at previous time step t-1
* 2: RU weights - [(nIn+nOut), 2*numUnits] - reset and update gates (input/recurrent weights) * 2: RU weights - [(inSize+numUnits), 2*numUnits] - reset and update gates (input/recurrent weights)
* 3: C weights - [(nIn+nOut), numUnits] - cell gate (input/recurrent weights) * 3: C weights - [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights)
* 4: reset and update biases, [2*numUnits] - reset and update gates * 4: reset and update biases, [2*numUnits] - reset and update gates
* 5: cell biases, [numUnits] * 5: cell biases, [numUnits]
* *
@ -275,7 +275,7 @@ namespace ops {
#endif #endif
#if NOT_EXCLUDED(OP_gruCell) #if NOT_EXCLUDED(OP_gruCell)
DECLARE_CUSTOM_OP(gruCell_bp, 6, 5, false, 0, 0); DECLARE_CUSTOM_OP(gruCell_bp, 10, 6, false, 0, 0);
#endif #endif
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -16,6 +16,7 @@
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <op_boilerplate.h> #include <op_boilerplate.h>
@ -24,6 +25,88 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC);
////////////////////////////////////////////////////////////////////////////////
template <typename T>
FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, T& s, T& v) {
// h values are in range [0, 360)
// s and v values are in range [0, 1]
const T max = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b));
const T min = nd4j::math::nd4j_min<T>(r, nd4j::math::nd4j_min<T>(g, b));
const T c = max - min;
// calculate h
if(c == 0) {
h = 0;
}
else if(max == r) {
h = 60.f * ((g - b) / c) + (g >= b ? 0 : 360);
}
else if(max == g) {
h = 60.f * ((b - r) / c) + 120;
}
else { // max == b
h = 60.f * ((r - g) / c) + 240;
}
// calculate s
s = max == (T)0 ? (T)0 : c / max;
// calculate v
v = max / 255.f;
}
////////////////////////////////////////////////////////////////////////////////
template <typename T>
FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& g, T& b) {
const float sector = h / 60.f;
const T c = v * s;
if(0.f <= sector && sector < 1.f) {
r = v;
g = v - c * (1 - sector);
b = v - c;
}
else if(1.f <= sector && sector < 2.f) {
r = v - c * (sector - 1);
g = v;
b = v - c;
}
else if(2.f <= sector && sector < 3.f) {
r = v - c;
g = v;
b = v - c * (3 - sector);
}
else if(3.f <= sector && sector < 4.f) {
r = v - c;
g = v - c * (sector - 3);
b = v;
}
else if(4.f <= sector && sector < 5.f) {
r = v - c * (5 - sector);
g = v - c;
b = v;
}
else { // 5.f <= sector < 6.f
r = v;
g = v - c;
b = v - c * (sector - 5);
}
r *= 255;
g *= 255;
b *= 255;
}
/*////////////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) { static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) {
T v_mid; T v_mid;
@ -83,6 +166,7 @@ namespace helpers {
*h = h_category + (increase ? ratio : (1 - ratio)); *h = h_category + (increase ? ratio : (1 - ratio));
} }
////////////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) { static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) {
int h_category = static_cast<int>(h); int h_category = static_cast<int>(h);
@ -128,7 +212,7 @@ namespace helpers {
} }
} }
void _adjust_hue(nd4j::LaunchContext * context, NDArray *input, NDArray *output, NDArray *delta, bool isNHWC); */
} }
} }
} }

View File

@ -16,6 +16,7 @@
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <op_boilerplate.h> #include <op_boilerplate.h>
@ -25,6 +26,10 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC);
/*
template <typename T> template <typename T>
static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) { static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) {
T vv = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b)); T vv = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b));
@ -109,8 +114,8 @@ namespace helpers {
*g = gg + m; *g = gg + m;
*b = bb + m; *b = bb + m;
} }
*/
void adjust_saturation(nd4j::LaunchContext * context, NDArray *input, NDArray *output, NDArray *delta, bool isNHWC);
} }
} }
} }

View File

@ -16,16 +16,84 @@
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <ops/declarable/helpers/adjust_hue.h> #include <ops/declarable/helpers/adjust_hue.h>
#include <helpers/ConstantTadHelper.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename T>
static void _adjust_hue_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {
const T delta = deltaScalarArr->e<T>(0);
const int rank = input->rankOf();
const T* x = input->bufferAsT<T>();
T* z = output->bufferAsT<T>();
if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong i = 0; i < input->lengthOf(); i += 3) {
T h, s, v;
rgbToHsv<T>(x[i], x[i+1], x[i+2], h, s, v);
h += delta * 360;
if(h > 360)
h -= 360;
else if(h < 0)
h += 360;
hsvToRgb<T>(h, s, v, z[i], z[i+1], z[i+2]);
}
}
else {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
const Nd4jLong numOfTads = packX.numberOfTads();
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
PRAGMA_OMP_PARALLEL_FOR_SIMD
for(Nd4jLong i = 0; i < numOfTads; ++i) {
const T* xTad = x + packX.platformOffsets()[i];
T* zTad = z + packZ.platformOffsets()[i];
T h, s, v;
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
h += delta * 360;
if(h > 360)
h -= 360;
else if(h < 0)
h += 360;
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
}
}
}
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), LIBND4J_TYPES);
}
/*
template <typename T>
static void adjust_hue_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
// we're 100% sure it's 3 // we're 100% sure it's 3
const int numChannels = 3; const int numChannels = 3;
int tuples = array->lengthOf() / numChannels; int tuples = array->lengthOf() / numChannels;
@ -93,7 +161,7 @@ namespace helpers {
} }
} }
void _adjust_hue(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) { void adjust_hue_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
auto xType = array->dataType(); auto xType = array->dataType();
float d = delta->e<float>(0); float d = delta->e<float>(0);
@ -104,18 +172,20 @@ namespace helpers {
// FIXME: template selector should be moved out of loop // FIXME: template selector should be moved out of loop
PRAGMA_OMP_PARALLEL_FOR PRAGMA_OMP_PARALLEL_FOR
for (int e = 0; e < tSize; e++) { for (int e = 0; e < tSize; e++) {
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES);
} }
delete tadsIn; delete tadsIn;
delete tadsOut; delete tadsOut;
} else { } else {
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES);
} }
} }
BUILD_SINGLE_TEMPLATE(template void _adjust_hue_single, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void adjust_hue_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES);
*/
} }
} }

View File

@ -16,15 +16,83 @@
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <ops/declarable/helpers/adjust_saturation.h> #include <ops/declarable/helpers/adjust_saturation.h>
#include <ops/declarable/helpers/adjust_hue.h>
#include <helpers/ConstantTadHelper.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T>
static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {
const T factor = factorScalarArr->e<T>(0);
const int rank = input->rankOf();
const T* x = input->bufferAsT<T>();
T* z = output->bufferAsT<T>();
if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong i = 0; i < input->lengthOf(); i += 3) {
T h, s, v;
rgbToHsv<T>(x[i], x[i+1], x[i+2], h, s, v);
s *= factor;
if(s > 1.f)
s = 1.f;
else if(s < 0.f)
s = 0.f;
hsvToRgb<T>(h, s, v, z[i], z[i+1], z[i+2]);
}
}
else {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
const Nd4jLong numOfTads = packX.numberOfTads();
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
PRAGMA_OMP_PARALLEL_FOR_SIMD
for(Nd4jLong i = 0; i < numOfTads; ++i) {
const T* xTad = x + packX.platformOffsets()[i];
T* zTad = z + packZ.platformOffsets()[i];
T h, s, v;
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
s *= factor;
if(s > 1.f)
s = 1.f;
else if(s < 0.f)
s = 0.f;
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
}
}
}
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {
BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), LIBND4J_TYPES);
}
/*
template <typename T> template <typename T>
static void adjust_saturation_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) { static void adjust_saturation_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
// we're 100% sure it's 3 // we're 100% sure it's 3
@ -108,6 +176,7 @@ namespace helpers {
} }
BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES);
*/
} }
} }

View File

@ -59,14 +59,17 @@ namespace helpers {
std::vector<Nd4jLong> dims(reduceShape->lengthOf()); std::vector<Nd4jLong> dims(reduceShape->lengthOf());
bool fit = true; bool fit = true;
PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
for( int i = 0; fit && (i < dims.size()); i++ ) { for( int i = 0; i < dims.size(); i++ ) {
if (fit) {
dims[i] = reduceShape->e<Nd4jLong>(i); dims[i] = reduceShape->e<Nd4jLong>(i);
for (int e = 0; fit && (e < input->rankOf()); ++e) for (int e = 0; e < input->rankOf(); ++e)
if (fit)
if (input->sizeAt(e) % dims[i]) { if (input->sizeAt(e) % dims[i]) {
fit = false; fit = false;
} }
} }
}
// check dims to fit input // check dims to fit input
REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank.");

View File

@ -35,82 +35,88 @@ namespace helpers {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc,
const NDArray* bru, const NDArray* bc, const NDArray* b, const NDArray* bc,
NDArray* r, NDArray* u, NDArray* c, NDArray* h) { NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
//Inputs: //Inputs:
// x input [bS, nIn], nIn - input size // x input [bS, iS], iS - input size
// hLast previous cell output [bS, nUn], that is at previous time step t-1, nUn - number of units // hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
// Wru RU weights - [nIn+nUn, 2*nUn] - reset and update gates // W RU weights - [iS+nU, 2*nU] - reset and update gates
// Wc C weights - [nIn+nUn, nUn] - cell gate // Wc C weights - [iS+nU, nU] - cell gate
// bru r and u biases, [2*nUn] - reset and update gates // b r and u biases, [2*nU] - reset and update gates
// bc c biases, [nUn] - cell gate // bc c biases, [nU] - cell gate
//Outputs: //Outputs:
// r Reset gate output [bS, nUn] // r Reset gate output [bS, nU]
// u Update gate output [bS, nUn] // u Update gate output [bS, nU]
// c Cell gate output [bS, nUn] // c Cell gate output [bS, nU]
// h current cell output [bS, nUn] // h current cell output [bS, nU]
/***************************************************************************************/ /***************************************************************************************/
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/ /************************ THIS IS NOT OPTIMAZED CODE ***********************************/
/** however it is more math-friendly and convenient for backprop formulas derivation) **/ /** however it is more math-friendly and convenient for backprop formulas derivation) **/
const int bS = x->sizeAt(0); const int bS = x->sizeAt(0);
const int nIn = x->sizeAt(1); const int iS = x->sizeAt(1);
const int nUn = hLast->sizeAt(1); const int nU = hLast->sizeAt(1);
NDArray Wr = (*Wru)({0,nIn, 0,0}); // reset gates weights [nIn, 2*nUn] NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
NDArray Wu = (*Wru)({nIn,nIn+nUn, 0,0}); // updates gates weights [nUn, 2*nUn] NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray Wcr = (*Wc)({0,nIn, 0,0}); // reset cell weights [nIn, nUn] NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
NDArray Wcu = (*Wc)({nIn,nIn+nUn, 0,0}); // updates cell weights [nUn, nUn] NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
// gates = sigmoid(x*Wr + hLast*Wu + br + bu) NDArray br = (*b)({0, nU}); // [nU]
NDArray gates = mmul(*x, Wr) + mmul(*hLast, Wu) + *bru; // [bS, nIn] * [nIn, 2*nUn] + [bS, nUn] * [nUn, 2*nUn] + [2*nUn] = [bS, 2*nUn] NDArray bu = (*b)({nU, 2*nU}); // [nU]
gates.applyTransform(transform::Sigmoid);
// × means matrix multipication
// * means element-wise product or so called Hadamard product
// reset gate // reset gate
r->assign(gates({0,0, 0,nUn})); // [bS, nUn] r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
r->applyTransform(transform::Sigmoid);
// update gate // update gate
u->assign(gates({0,0, nUn,2*nUn})); // [bS, nUn] u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
u->applyTransform(transform::Sigmoid);
// cell gate c = activation(x*Wcr + (r◦hlast)*Wcu + bc) // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc)
c->assign(mmul(*x, Wcr) + mmul(*r * *hLast, Wcu) + *bc); // [bS, nIn] * [nIn, nUn] + [bS, nUn] * [nUn, nUn] + [nUn] = [bS, nUn] c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
c->applyTransform(transform::Tanh); c->applyTransform(transform::Tanh);
NDArray temp = 1.f - *c * *c;
// cell output // cell output
h->assign(*u * *hLast + (1.f - *u) * *c); h->assign(*u * *hLast + (1.f - *u) * *c);
/***************************************************************************************/ /***************************************************************************************/
/********************** THIS MORE OPTIMAZED CODE (except concat ) **********************/ /*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/
/***************************************************************************************/ /***************************************************************************************/
/* /*
//Concat inputs: x + hLast : [bs, nIn + nUn] //Concat inputs: x + hLast : [bs, iS + nU]
NDArray xhConcat(x->ordering(), {bS, nIn + nUn}, x->dataType(), context); // concat([bs, nIn], [bs, nUn]) -> [bs, nIn + nUn] NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU]
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1}); helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
//mmul for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u) //mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u)
auto m = mmul(xhConcat, *Wru) + *bru ; // [bs, nIn+nUn] * [nIn+nUn, 2*nUn] = [bs, 2*nUn] auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU]
// m += *bru; // m += *bru;
sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz) m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz)
r->assign(m({0,0, 0, nUn})); r->assign(m({0,0, 0, nU}));
u->assign(m({0,0, nUn, 2*nUn})); u->assign(m({0,0, nU, 2*nU}));
// hLast = hLast * r // hLast = hLast * r
xhConcat({0,0, nIn, nIn+nUn}) *= *r; xhConcat({0,0, iS, iS+nU}) *= *r;
//c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c) //c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c)
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
*c += *bc; *c += *bc;
tanhInplace(*c); c->applyTransform(transform::Tanh);
//Output: h = (1-u).*c + u .* hPrev //Output: h = (1-u).*c + u .* hPrev
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult); //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
@ -122,19 +128,19 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
// x input [time, bS, iS] // x input [time, bS, iS]
// h0 initial cell output (at time step = 0) [bS, nUn] // hLast initial cell output (at time step = 0) [bS, nU]
// Wx input-to-hidden weights, [iS, 3*nUn] // Wx input-to-hidden weights, [iS, 3*nU]
// Wh hidden-to-hidden weights, [nUn, 3*nUn] // Wh hidden-to-hidden weights, [nU, 3*nU]
// b biases, [3*nUn] // b biases, [3*nU]
// h is cell outputs at each time step [time, bS, nUn] // h is cell outputs at each time step [time, bS, nU]
const int time = x->sizeAt(0); const int time = x->sizeAt(0);
NDArray ht_1(*h0); NDArray ht_1(*hLast);
// loop through time steps // loop through time steps
for (int t = 0; t < time; ++t) { for (int t = 0; t < time; ++t) {
@ -148,105 +154,208 @@ void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0, void gruCellBP(nd4j::LaunchContext* context,
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { const NDArray* x, const NDArray* hLast,
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
NDArray* dLdx, NDArray* dLdhLast,
NDArray* dLdW, NDArray* dLdWc,
NDArray* dLdb, NDArray* dLdbc) {
//Inputs:
// x input [bS, iS] // x input [bS, iS]
// h0 previous cell output [bS, nUn], that is at previous time step t-1 // hLast previous cell output [bS, nU], that is at previous time step t-1
// Wx input-to-hidden weights, [iS, 3*nUn] // W weights - [iS+nU, 2*nU] - reset and update gates
// Wh hidden-to-hidden weights, [nUn, 3*nUn] // Wc C weights - [iS+nU, nU] - cell gate
// b biases, [3*nUn] // b r and u biases, [2*nU] - reset and update gates
// dLdh gradient wrt output, [bS,nUn], that is epsilon_next // bc c biases, [nU] - cell gate
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nUn] // dLdr gradient wrt reset gate, [bS, nU]
// dLdWh0 gradient wrt Wh at previous time step, [nUn, 3*nUn] // dLdu gradient wrt update gate, [bS, nU]
// dLdb0 gradient wrt b at previous time step, [3*nUn] // dLdc gradient wrt cell state, [bS, nU]
// dLdh gradient wrt current cell output, [bS, nU]
// dLdx gradient wrt x, [bS, iS], that is epsilon //Outputs:
// dLdh0 gradient wrt h0, [bS, nUn] // dLdx gradient wrt x, [bS, iS],
// dLdWx gradient wrt Wx, [iS, 3*nUn] // dLdhLast gradient wrt hLast, [bS, nU]
// dLdWh gradient wrt Wh, [nUn, 3*nUn] // dLdW gradient wrt W, [iS+nU, 2*nU]
// dLdb gradient wrt b at previous time step, [3*nUn] // dLdWc gradient wrt Wc, [iS+nU, nU]
// dLdb gradient wrt bru [2*nU]
// dLdbc gradient wrt bc [nU]
// h is current cell output [bS, nUn], that is at current time step t // * means element-wise product or so called Hadamard product
// × means matrix multiplication
/************************************************************************************************/
/******************************* THIS IS NOT OPTIMAZED CODE *************************************/
/*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/
const int bS = x->sizeAt(0);
const int iS = x->sizeAt(1);
const int nU = hLast->sizeAt(1);
NDArray xT = x->transpose(); // [iS, bS]
NDArray hLastT = hLast->transpose(); // [nU, bS]
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
NDArray br = (*b)({0, nU}); // [nU]
NDArray bu = (*b)({nU, 2*nU}); // [nU]
NDArray WrxT = Wrx.transpose(); // [nU, iS]
NDArray WuxT = Wux.transpose(); // [nU, iS]
NDArray WrhT = Wrh.transpose(); // [nU, nU]
NDArray WuhT = Wuh.transpose(); // [nU, nU]
NDArray WcxT = Wcx.transpose(); // [nU, iS]
NDArray WchT = Wch.transpose(); // [nU, nU]
NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU]
NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU]
NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU]
NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU]
NDArray dLdbr = (*dLdb)({0, nU}); // [nU]
NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU]
const int nUn = h0->sizeAt(1);
// ***** feed forward step ***** // // ***** feed forward step ***** //
// gates = sigmoid(x*Wx + h0*Wh + b)
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nUn})) + mmul(*h0, (*Wh)({0,0, 0,2*nUn})) + (*b)({0,2*nUn})); // [bS, 2*nUn] + [bS, 2*nUn] + [1, 2*nUn] = [bS, 2*nUn]
// reset gate // reset gate
auto r = gates({0,0, 0, nUn}); // [bS, nUn] NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
r.applyTransform(transform::Sigmoid);
// update gate // update gate
auto u = gates({0,0, nUn, 2*nUn}); // [bS, nUn] NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
// ◦ means element-wise product or so called Hadamard product u.applyTransform(transform::Sigmoid);
// n = tanh(x*Wx + (r◦h0)*Wh + b)
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nUn,3*nUn})) + mmul((*h0)*r, (*Wh)({0,0, 2*nUn,3*nUn})) + (*b)({2*nUn,3*nUn})); // [bS, nUn] // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc)
NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
c.applyTransform(transform::Tanh);
// h = (1 - u) * c + u * hPrev
// ***** back prop step ***** // // ***** back prop step ***** //
auto Wxr = (*Wx)({0,0, 0, nUn});
auto Wxu = (*Wx)({0,0, nUn, 2*nUn});
auto Wxn = (*Wx)({0,0, 2*nUn,3*nUn});
auto Whr = (*Wh)({0,0, 0, nUn});
auto Whu = (*Wh)({0,0, nUn, 2*nUn});
auto Whn = (*Wh)({0,0, 2*nUn,3*nUn});
auto WxrT = Wxr.transpose();
auto WxuT = Wxu.transpose();
auto WxnT = Wxn.transpose();
auto WhrT = Whr.transpose();
auto WhuT = Whu.transpose();
auto WhnT = Whn.transpose();
auto xT = x->transpose();
auto h0T = h0->transpose();
auto dLdWxr = (*dLdWx)({0,0, 0, nUn}); // notations:
auto dLdWxu = (*dLdWx)({0,0, nUn, 2*nUn}); // Zr = x × Wrx + hLast × Wrh + br
auto dLdWxn = (*dLdWx)({0,0, 2*nUn,3*nUn}); // Zu = x × Wux + hLast × Wuh + bu
// Sr = sigmoid(Zr)
// Su = sigmoid(Zu)
// Zc = x × Wcx + (r * hlast) × Wch + bc
auto dLdWhr = (*dLdWh)({0,0, 0, nUn});
auto dLdWhu = (*dLdWh)({0,0, nUn, 2*nUn});
auto dLdWhn = (*dLdWh)({0,0, 2*nUn,3*nUn});
auto dLdbr = (*dLdb)({0, nUn}); // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
auto dLdbu = (*dLdb)({nUn, 2*nUn}); // = dLdx_u + dLdx_c
auto dLdbn = (*dLdb)({2*nUn,3*nUn}); // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
// dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
// dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
// dZcdr = (... * hLast) × WchT
// dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
// drdx = drdZr * dZrdx
// dZrdx = ... × WrxT
// dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
// finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
auto dhdu = *h0 - n; // [bS, nUn]
auto dhdn = 1.f - u; // [bS, nUn]
auto dSigdu = u * (1.f - u); // [bS, nUn]
auto dSigdr = r * (1.f - r); // [bS, nUn]
auto dActdn = 1.f - n * n; // [bS, nUn]
auto dndr = mmul(dActdn * (*h0), WhnT);
auto drdh0 = mmul(dSigdr, WhrT);
auto dLdn = (*dLdh) * dhdn; // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
auto dLdu = (*dLdh) * dhdu; // = dLdhLast_h + dLdhLast_u + dLdhLast_c
auto dLdr = dLdn * dndr; // dLdhLast_h = dLdh * dhdhLas = dLdh * u
// dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
// dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
// = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
// = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
// dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
// dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
// finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
// = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nUn]
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nUn] // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nUn,nUn] // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
// dZrdWrx = xT × ...
// finally dLdWrx = xT × (dLdr * drdZr)
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nUn]
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nUn,nUn]
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nUn] // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nUn,nUn] // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
// dZrdWrh = hLastT × ...
// finally dLdWrh = hLastT × (dLdr * drdZr)
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nUn]
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nUn]
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nUn]
if(dLdWx0 != nullptr) // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
*dLdWx += *dLdWx0; // dZudWux = xT × ...
// dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)
if(dLdWh0 != nullptr)
*dLdWh += *dLdWh0;
if(dLdb0 != nullptr) // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
*dLdb += *dLdb0; // dZudWuh = hLastT × ...
// finally dLdWuh = hLastT × (dLdu * dudZu)
// dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
// dZcdWcx = xT × ...
// finally dLdWcx = xT × (dLdc * dcdZc)
// dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
// dZcdWch = (r*hLast)^T × ...
// finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
// dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
// = dLdr * drdZr * dZrdbr
// dZrdbr = 1
// finally dLdbr = dLdr * drdZr
// dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
// dZudbu = 1
// finally dLdbu = dLdu * dudZu
// dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
// dZcdbc = 1
// finally dLdbc = dLdc * dcdZc
NDArray dhdc = 1.f - u; // [bS, nU]
NDArray dhdu = *hLast - c; // [bS, nU]
NDArray dudZu = u * dhdc; // [bS, nU]
NDArray drdZr = r * (1.f - r); // [bS, nU]
NDArray dcdZc = 1.f - c * c; // [bS, nU]
NDArray dLdZc = *dLdc * dcdZc; // [bS, nU]
NDArray dLdZu = *dLdu * dudZu; // [bS, nU]
NDArray dLdZr = *dLdr * drdZr; // [bS, nU]
// NDArray dLdc = *dLdh * dhdc; // [bS, nU]
// NDArray dLdu = *dLdh * dhdu; // [bS, nU]
// NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU]
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS]
dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU]
dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0})); // [nU]
dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0})); // [nU]
dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU]
} }
// ////////////////////////////////////////////////////////////////////////// // //////////////////////////////////////////////////////////////////////////
@ -255,34 +364,34 @@ void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h
// void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) { // void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
// NDArray<T>* x = inArrs[0]; // input [time, bS, iS] // NDArray<T>* x = inArrs[0]; // input [time, bS, iS]
// NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nUn], that is at previous time step t-1 // NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nU], that is at previous time step t-1
// NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nUn] // NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nU]
// NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nUn, 3*nUn] // NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nU, 3*nU]
// NDArray<T>* b = inArrs[4]; // biases, [3*nUn] // NDArray<T>* b = inArrs[4]; // biases, [3*nU]
// NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nUn], that is epsilon_next // NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nU], that is epsilon_next
// NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon // NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon
// NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nUn] // NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nU]
// NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nUn] // NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nU]
// NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nUn, 3*nUn] // NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nU, 3*nU]
// NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nUn] // NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nU]
// const Nd4jLong time = x->sizeAt(0); // const Nd4jLong time = x->sizeAt(0);
// const Nd4jLong bS = x->sizeAt(1); // const Nd4jLong bS = x->sizeAt(1);
// const Nd4jLong iS = x->sizeAt(2); // const Nd4jLong iS = x->sizeAt(2);
// const Nd4jLong nUn = hi->sizeAt(1); // const Nd4jLong nU = hi->sizeAt(1);
// NDArray<T> h(hi->ordering(), {time, bS, nUn}); // feed forward output // NDArray<T> h(hi->ordering(), {time, bS, nU}); // feed forward output
// // first step, time = 0, feed forward // // first step, time = 0, feed forward
// NDArray<T> x0 = (*x)({{0,1}, {}, {}}); // NDArray<T> x0 = (*x)({{0,1}, {}, {}});
// NDArray<T> h0 = h({{0,1}, {}, {}}); // NDArray<T> hLast = h({{0,1}, {}, {}});
// helpers::gruCell<T>({&x0, hi, Wx, Wh, b}, &h0); // helpers::gruCell<T>({&x0, hi, Wx, Wh, b}, &hLast);
// // first step, time = 0, back prop // // first step, time = 0, back prop
// NDArray<T> dLdx0 = (*dLdx)({{0,1}, {}, {}}); // NDArray<T> dLdx0 = (*dLdx)({{0,1}, {}, {}});
// NDArray<T> dLdh0 = (*dLdh)({{0,1}, {}, {}}); // NDArray<T> dLdhLast = (*dLdh)({{0,1}, {}, {}});
// helpers::gruCellBP<T>({&x0, hi, Wx, Wh, b, &dLdh0, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb}); // helpers::gruCellBP<T>({&x0, hi, Wx, Wh, b, &dLdhLast, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb});
// // loop through the rest time steps // // loop through the rest time steps
// for (Nd4jLong t = time-1; t > 0; --t) { // for (Nd4jLong t = time-1; t > 0; --t) {
@ -310,4 +419,3 @@ void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h
} }
} }
} }

View File

@ -20,6 +20,8 @@
#include <ops/declarable/helpers/image_suppression.h> #include <ops/declarable/helpers/image_suppression.h>
//#include <blas/NDArray.h> //#include <blas/NDArray.h>
#include <algorithm>
#include <numeric>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -28,9 +30,8 @@ namespace helpers {
template <typename T> template <typename T>
static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
std::vector<Nd4jLong> indices(scales->lengthOf()); std::vector<Nd4jLong> indices(scales->lengthOf());
std::iota(indices.begin(), indices.end(), 0);
for (size_t i = 0; i < indices.size(); ++i)
indices[i] = i;
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);}); std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
// std::vector<int> selected(output->lengthOf()); // std::vector<int> selected(output->lengthOf());
@ -62,13 +63,15 @@ namespace helpers {
}; };
// int numSelected = 0; // int numSelected = 0;
int numBoxes = boxes->sizeAt(0); int numBoxes = boxes->sizeAt(0);
int numSelected = 0;
for (int i = 0, numSelected = 0; i < numBoxes && numSelected < output->lengthOf(); ++i) { for (int i = 0; i < numBoxes; ++i) {
bool shouldSelect = true; bool shouldSelect = numSelected < output->lengthOf();
PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(numSelected))
for (int j = numSelected - 1; j >= 0; --j) { for (int j = numSelected - 1; j >= 0; --j) {
if (shouldSelect)
if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold))) { if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold))) {
shouldSelect = false; shouldSelect = false;
break;
} }
} }
if (shouldSelect) { if (shouldSelect) {

View File

@ -24,20 +24,20 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename I, typename B>
static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) { static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) {
PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2) PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2)
for (Nd4jLong i = 0; i < maxIndex; i++) for (Nd4jLong i = 0; i < maxIndex; i++)
for(Nd4jLong k = 0; k < input->lengthOf(); k++) for(Nd4jLong k = 0; k < input->lengthOf(); k++)
if (i < input->e<int>(k)) if (i < input->t<I>(k))
output->p<T>(k * maxIndex + i, T(1.0f)); output->t<B>(k * maxIndex + i) = B(true); //, T(1.0f));
} }
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
BUILD_SINGLE_SELECTOR(input->dataType(), sequenceMask_, (input, output, maxIndex), LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), LIBND4J_TYPES); BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
} }
} }
} }

View File

@ -27,6 +27,7 @@
#include <helpers/TAD.h> #include <helpers/TAD.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <Loops.h> #include <Loops.h>
#include <graph/RandomGenerator.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -81,7 +82,7 @@ static void trace_(const NDArray& input, NDArray& output) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) { void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
// check edge cases first // check edge cases first
int temp; int temp;
@ -95,16 +96,16 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
// apply Fisher-Yates shuffle // apply Fisher-Yates shuffle
if(isInplace) { if(isInplace) {
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
for(int i = firstDim-1; i > 0; --i) { for(int i = firstDim-1; i > 0; --i) {
int r = rng.nextInt(0, i); int r = rng.relativeInt(i) % i;
if(i == r) if(i == r)
continue; continue;
T _e0 = input.e<T>(i); T t0 = input.t<T>(i);
T _e1 = input.e<T>(r); T t1 = input.t<T>(r);
//math::nd4j_swap<T>(input(i), input(r)); //math::nd4j_swap<T>(input(i), input(r));
input.p<T>(i, _e1); input.t<T>(i) = t1;
input.p<T>(r, _e0); input.t<T>(r) = t0;
} }
} }
else { else {
@ -113,12 +114,12 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
output.p<T>(Nd4jLong(0), input.e<T>(0)); output.p<T>(Nd4jLong(0), input.e<T>(0));
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
for(int i = firstDim-1; i > 0; --i) { for(int i = firstDim-1; i > 0; --i) {
int r = rng.nextInt(0, i); int r = rng.relativeInt(i) % i;
output.p(i, input.e<T>(indices[r])); output.t<T>(i) = input.t<T>(indices[r]);
if(i == r) if(i == r)
continue; continue;
output.p(r, input.e<T>(indices[i])); output.t<T>(r) = input.t<T>(indices[i]);
math::nd4j_swap<int>(indices[i], indices[r]); math::nd4j_swap<int>(indices[i], indices[r]);
} }
rng.rewindH(firstDim-1); rng.rewindH(firstDim-1);
@ -132,9 +133,10 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
// apply Fisher-Yates shuffle // apply Fisher-Yates shuffle
if(isInplace) { if(isInplace) {
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold()) //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
for(int i = firstDim - 1; i > 0; --i) { for(int i = firstDim - 1; i > 0; --i) {
int r = rng.nextInt(0, i); int r = rng.relativeInt(i) % i;
if(i == r) if(i == r)
continue; continue;
subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r)); subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r));
@ -146,9 +148,9 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
std::vector<int> indices(firstDim); std::vector<int> indices(firstDim);
std::iota(indices.begin(), indices.end(), 0); std::iota(indices.begin(), indices.end(), 0);
bool isZeroShuffled = false; bool isZeroShuffled = false;
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold()) //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
for(int i = firstDim - 1; i > 0; --i) { for(int i = firstDim - 1; i > 0; --i) {
int r = rng.nextInt(0, i); int r = rng.relativeInt(i) % i;
subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r])); subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r]));
if(r == 0) if(r == 0)
isZeroShuffled = true; isZeroShuffled = true;
@ -167,11 +169,11 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
} }
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) { void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);

View File

@ -16,15 +16,92 @@
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <ops/declarable/helpers/adjust_hue.h> #include <ops/declarable/helpers/adjust_hue.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <PointersManager.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
///////////////////////////////////////////////////////////////////
template <typename T>
static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets,
const Nd4jLong numOfTads, const T delta, const int dimC) {
const T* x = reinterpret_cast<const T*>(vx);
T* z = reinterpret_cast<T*>(vz);
__shared__ int rank;
__shared__ Nd4jLong xDimCstride, zDimCstride;
if (threadIdx.x == 0) {
rank = shape::rank(xShapeInfo);
xDimCstride = shape::stride(xShapeInfo)[dimC];
zDimCstride = shape::stride(zShapeInfo)[dimC];
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
const T* xTad = x + xTadOffsets[i];
T* zTad = z + zTadOffsets[i];
T h, s, v;
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
h += delta * 360;
if(h > 360)
h -= 360;
else if(h < 0)
h += 360;
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
static _CUDA_H void adjustHueCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC) {
adjustHueCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, deltaScalarArr->e<T>(0), dimC);
}
BUILD_SINGLE_TEMPLATE(template void adjustHueCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
const Nd4jLong numOfTads = packX.numberOfTads();
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;
PointersManager manager(context, "adjustHue");
NDArray::prepareSpecialUse({output}, {input, deltaScalarArr});
BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, deltaScalarArr, dimC), LIBND4J_TYPES);
NDArray::registerSpecialUse({output}, {input, deltaScalarArr});
manager.synchronize();
}
/*
template <typename T> template <typename T>
static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
int numChannels = 3; int numChannels = 3;
@ -134,11 +211,13 @@ namespace helpers {
float d = delta->e<float>(0); float d = delta->e<float>(0);
if (array->rankOf() == 4) { if (array->rankOf() == 4) {
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES);
} else { } else {
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
} }
} }
*/
} }
} }
} }

View File

@ -16,16 +16,93 @@
// //
// @author raver119@gmail.com // @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <ops/declarable/helpers/adjust_saturation.h> #include <ops/declarable/helpers/adjust_saturation.h>
#include <ops/declarable/helpers/adjust_hue.h>
#include <helpers/ConstantTadHelper.h> #include <helpers/ConstantTadHelper.h>
#include <PointersManager.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
///////////////////////////////////////////////////////////////////
template <typename T>
static void _CUDA_G adjustSaturationCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets,
const Nd4jLong numOfTads, const T factor, const int dimC) {
const T* x = reinterpret_cast<const T*>(vx);
T* z = reinterpret_cast<T*>(vz);
__shared__ int rank;
__shared__ Nd4jLong xDimCstride, zDimCstride;
if (threadIdx.x == 0) {
rank = shape::rank(xShapeInfo);
xDimCstride = shape::stride(xShapeInfo)[dimC];
zDimCstride = shape::stride(zShapeInfo)[dimC];
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
const T* xTad = x + xTadOffsets[i];
T* zTad = z + zTadOffsets[i];
T h, s, v;
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
s *= factor;
if(s > 1.f)
s = 1.f;
else if(s < 0.f)
s = 0.f;
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
static _CUDA_H void adjustSaturationCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC) {
adjustSaturationCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, factorScalarArr->e<T>(0), dimC);
}
BUILD_SINGLE_TEMPLATE(template void adjustSaturationCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
const Nd4jLong numOfTads = packX.numberOfTads();
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;
PointersManager manager(context, "adjustSaturation");
NDArray::prepareSpecialUse({output}, {input, factorScalarArr});
BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, factorScalarArr, dimC), LIBND4J_TYPES);
NDArray::registerSpecialUse({output}, {input, factorScalarArr});
manager.synchronize();
}
/*
template <typename T> template <typename T>
static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) { static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
int numChannels = 3; int numChannels = 3;
@ -129,7 +206,7 @@ namespace helpers {
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
} }
} }
*/
} }
} }

View File

@ -22,20 +22,99 @@
#include <NativeOps.h> #include <NativeOps.h>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <cuda_exception.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename T>
static void dropoutSimple(NDArray const* input, NDArray* output, double probValue, int seed) { static __global__ void dropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong* outputShape, double probVal, int inLen, nd4j::graph::RandomGenerator* nodeRng) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
__shared__ T const* input;
__shared__ T* output;
if (threadIdx.x == 0) {
input = reinterpret_cast<T const*>(inputBuf);
output = reinterpret_cast<T*>(outputBuf);
} }
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
for (Nd4jLong e = 0; e < inLen; ++e) {
T val = nodeRng->relativeT(e, T(0.f), T(1.f));
if (double(val) < probVal)
output[shape::getIndexOffset(e, outputShape, inLen)] = T(input[shape::getIndexOffset(e, inputShape, inLen)] / probVal);
}
}
template <typename T>
static void dropoutSimple(nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed) {
nd4j::graph::RandomGenerator nodeRng(3019L, seed);
int inLen = input->lengthOf();
nd4j::graph::RandomGenerator* dRandom;
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input});
auto err = cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator));
if (err) {
throw cuda_exception::build("helpers::dropoutSimple: Cannot allocate device memory for random generator.", err);
}
err = cudaMemcpy(dRandom, &nodeRng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice);
if (err) {
throw cuda_exception::build("helpers::dropoutSimple: Cannot set up device memory for random generator.", err);
}
dropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, inLen, dRandom);
err = cudaFree(dRandom);
if (err) {
throw cuda_exception::build("helpers::dropoutSimple: Cannot deallocate device memory for random generator.", err);
}
NDArray::registerSpecialUse({output}, {input});
}
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
template <typename T> template <typename T>
int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
if (reduceShape == nullptr){
dropoutSimple<T>(context.launchContext(), input, output, probValue, seed);
}
else {
REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input");
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
reduceShape->syncToHost(); // to ensure that follows are actual
bool fit = true;
// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
for( int i = 0; i < dims.size(); i++ ) {
if (fit) {
dims[i] = reduceShape->e<Nd4jLong>(i);
for (int e = 0; e < input->rankOf(); ++e)
if (fit)
if (input->sizeAt(e) % dims[i]) {
fit = false;
}
}
}
// check dims to fit input
REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank.");
std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
chunk->assign(1.f);
//chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue);
//NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob);
dropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), probValue, seed);
// broadcast chunk to full matrix
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
dropOutMultiplier->assign(1.f);
*dropOutMultiplier += *chunk;
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
}
return Status::OK(); return Status::OK();
} }
@ -48,14 +127,121 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES);
/////////////////////////////////// backrpopagations /////////////////////////////////////////////// /////////////////////////////////// backrpopagations ///////////////////////////////////////////////
template <typename T>
static __global__ void dropoutBPKernel(void* outputBuf, Nd4jLong* outputShape, void* gradOutBuf, Nd4jLong* gradOutShape, double probValue) {
__shared__ T* output;
__shared__ T* input;
__shared__ int len;
if (threadIdx.x == 0) {
len = shape::length(outputShape);
output = reinterpret_cast<T*>(outputBuf);
input = reinterpret_cast<T*>(gradOutBuf);
}
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for (int e = tid; e < len; e += step) {
if (output[shape::getIndexOffset(e, outputShape, len)] != T(0.))
output[shape::getIndexOffset(e, outputShape, len)] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
}
}
template <typename T> template <typename T>
static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
return Status::OK(); int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue);
auto stream = context.launchContext()->getCudaStream();
if (ND4J_STATUS_OK == res)
dropoutBPKernel<T><<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue);
return res;
}
template <typename T>
static __global__ void alphaDropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong* outputShape, double probValue, double alpha, double alpha1, double beta, int inLen, nd4j::graph::RandomGenerator* nodeRng) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
__shared__ T const* input;
__shared__ T* output;
if (threadIdx.x == 0) {
input = reinterpret_cast<T const*>(inputBuf);
output = reinterpret_cast<T*>(outputBuf);
}
for (auto e = tid; e < inLen; e += step) {
T val = nodeRng->relativeT(e, T(0.f), T(1.f));
T xVal = input[shape::getIndexOffset(e, inputShape, inLen)];
output[shape::getIndexOffset(e, outputShape, inLen)] = (val >= T(probValue) ? T(alpha * beta + alpha1) : T(alpha * (double)xVal + alpha1));
}
}
template <typename T>
static void alphaDropoutSimple(nd4j::LaunchContext* context, NDArray const* input, NDArray* output, int seed, double probValue, double alpha, double alpha1, double beta) {
nd4j::graph::RandomGenerator nodeRng(3019L, seed), *dRandom;
auto stream = context->getCudaStream();
auto err = cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator));
NDArray::prepareSpecialUse({output}, {input});
if (err) {
throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot allocate device memory for random generator.", err);
}
err = cudaMemcpy(dRandom, &nodeRng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice);
if (err) {
throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot set up device memory for random generator.", err);
}
alphaDropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, alpha, alpha1, beta, output->lengthOf(), dRandom);
err = cudaFree(dRandom);
if (err) {
throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot deallocate device memory for random generator.", err);
}
NDArray::registerSpecialUse({output}, {input});
} }
template <typename T> template <typename T>
static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output,
NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {
if (reduceShape == nullptr){
alphaDropoutSimple<T>(context.launchContext(), input, output, seed, probValue, alpha, alpha1, beta);
}
else {
REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input");
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
reduceShape->syncToHost(); // to ensure that follows are actual
bool fit = true;
// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
for( int i = 0; i < dims.size(); i++ ) {
if (fit) {
dims[i] = reduceShape->e<Nd4jLong>(i);
for (int e = 0; e < input->rankOf(); ++e)
if (fit)
if (input->sizeAt(e) % dims[i]) {
fit = false;
}
}
}
// check dims to fit input
REQUIRE_TRUE(fit, 0, "alpha_dropout: Noise shape should fit to input rank.");
std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
chunk->assign(1.f);
//chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue);
//NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob);
alphaDropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), seed, probValue, alpha, alpha1, beta);
// broadcast chunk to full matrix
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
dropOutMultiplier->assign(1.f);
*dropOutMultiplier += *chunk;
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
}
return Status::OK(); return Status::OK();
} }
@ -63,7 +249,12 @@ namespace helpers {
int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output,
NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {
return Status::OK(); int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta);
if (res == ND4J_STATUS_OK) {
(*output) *= alpha;
(*output) *= (*gradOut); //->applyPairwiseTransform<transform::Multiply>(gradOut, output, nullptr);
}
return res;
} }
int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {

View File

@ -35,58 +35,88 @@ namespace helpers {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc,
const NDArray* bru, const NDArray* bc, const NDArray* b, const NDArray* bc,
NDArray* r, NDArray* u, NDArray* c, NDArray* h) { NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
//Inputs: //Inputs:
// x input [bS x inSize] // x input [bS, iS], iS - input size
// hLast previous cell output [bS x numUnits], that is at previous time step t-1 // hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
// Wru RU weights - [bS, 2*numUnits] - reset and update gates // W RU weights - [iS+nU, 2*nU] - reset and update gates
// Wc C weights - [bS, numUnits] - cell gate // Wc C weights - [iS+nU, nU] - cell gate
// bru r and u biases, [2*numUnits] - reset and update gates // b r and u biases, [2*nU] - reset and update gates
// bc c biases, [numUnits] - cell gate // bc c biases, [nU] - cell gate
//Outputs: //Outputs:
// r Reset gate output [bS, numUnits] // r Reset gate output [bS, nU]
// u Update gate output [bS, numUnits] // u Update gate output [bS, nU]
// c Cell gate output [bS, numUnits] // c Cell gate output [bS, nU]
// h current cell output [bS, numUnits] // h current cell output [bS, nU]
const int nIn = x->sizeAt(1); /***************************************************************************************/
const int nU = hLast->sizeAt(1); // number of units /************************ THIS IS NOT OPTIMAZED CODE ***********************************/
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
//Concat inputs: [x, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] const int bS = x->sizeAt(0);
nd4j::ops::concat concatOp; const int iS = x->sizeAt(1);
std::vector<NDArray*> inputs; const int nU = hLast->sizeAt(1);
std::vector<double> targs;
std::vector<Nd4jLong> iargs({1}); //Axis = 1
std::vector<bool> bargs;
inputs.emplace_back(const_cast<NDArray*>(x));
inputs.emplace_back(const_cast<NDArray*>(hLast));
auto result = concatOp.execute(inputs, targs, iargs, bargs); NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
auto concatOut = result->at(0); NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
//mmul/z for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u) NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
auto m = mmul(*concatOut, *Wru); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 2*numUnits] = [bs, 4*numUnits] NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
m += (*bru);
sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz) NDArray br = (*b)({0, nU}); // [nU]
auto mr = m({0,0, 0, nU}); NDArray bu = (*b)({nU, 2*nU}); // [nU]
auto mu = m({0,0, nU, 2*nU});
r->assign(&mr); // × means matrix multipication
u->assign(&mu); // * means element-wise product or so called Hadamard product
//Concatenated inputs: [x, yt-1 .* r] // reset gate
auto yr = (*concatOut)({0,0, nIn, nIn+nU}); r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
yr *= (*r); r->applyTransform(transform::Sigmoid);
//c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c) // update gate
MmulHelper::mmul(concatOut, const_cast<NDArray*>(Wc), c, 1.0, 0.0); //c = 1.0 * concatOut * Wc + 0.0 * c u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
u->applyTransform(transform::Sigmoid);
// cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc)
c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
c->applyTransform(transform::Tanh);
NDArray temp = 1.f - *c * *c;
// cell output
h->assign(*u * *hLast + (1.f - *u) * *c);
/***************************************************************************************/
/*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/
/***************************************************************************************/
/*
//Concat inputs: x + hLast : [bs, iS + nU]
NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU]
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
//mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u)
auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU]
// m += *bru;
m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz)
r->assign(m({0,0, 0, nU}));
u->assign(m({0,0, nU, 2*nU}));
// hLast = hLast * r
xhConcat({0,0, iS, iS+nU}) *= *r;
//c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c)
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
*c += *bc; *c += *bc;
tanhInplace(*c); c->applyTransform(transform::Tanh);
//Output: h = (1-u).*c + u .* hPrev //Output: h = (1-u).*c + u .* hPrev
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult); //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
@ -94,115 +124,238 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
auto temp = (1.0f - *u); auto temp = (1.0f - *u);
temp *= (*c); temp *= (*c);
(*h) += temp; (*h) += temp;
*/
delete result;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
} // x input [time, bS, iS]
// hLast initial cell output (at time step = 0) [bS, nU]
//////////////////////////////////////////////////////////////////////////
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
// x input [bS, iS]
// h0 previous cell output [bS, nU], that is at previous time step t-1
// Wx input-to-hidden weights, [iS, 3*nU] // Wx input-to-hidden weights, [iS, 3*nU]
// Wh hidden-to-hidden weights, [nU, 3*nU] // Wh hidden-to-hidden weights, [nU, 3*nU]
// b biases, [3*nU] // b biases, [3*nU]
// dLdh gradient wrt output, [bS,nU], that is epsilon_next
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nU]
// dLdWh0 gradient wrt Wh at previous time step, [nU, 3*nU]
// dLdb0 gradient wrt b at previous time step, [3*nU]
// dLdx gradient wrt x, [bS, iS], that is epsilon // h is cell outputs at each time step [time, bS, nU]
// dLdh0 gradient wrt h0, [bS, nU]
// dLdWx gradient wrt Wx, [iS, 3*nU]
// dLdWh gradient wrt Wh, [nU, 3*nU]
// dLdb gradient wrt b at previous time step, [3*nU]
// h is current cell output [bS, nU], that is at current time step t const int time = x->sizeAt(0);
NDArray ht_1(*hLast);
// loop through time steps
for (int t = 0; t < time; ++t) {
auto xt = (*x)({t,t+1, 0,0, 0,0});
auto ht = (*h)({t,t+1, 0,0, 0,0});
// helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
// ht_1.assign(ht);
}
}
//////////////////////////////////////////////////////////////////////////
void gruCellBP(nd4j::LaunchContext* context,
const NDArray* x, const NDArray* hLast,
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
NDArray* dLdx, NDArray* dLdhLast,
NDArray* dLdW, NDArray* dLdWc,
NDArray* dLdb, NDArray* dLdbc) {
//Inputs:
// x input [bS, iS]
// hLast previous cell output [bS, nU], that is at previous time step t-1
// W weights - [iS+nU, 2*nU] - reset and update gates
// Wc C weights - [iS+nU, nU] - cell gate
// b r and u biases, [2*nU] - reset and update gates
// bc c biases, [nU] - cell gate
// dLdr gradient wrt reset gate, [bS, nU]
// dLdu gradient wrt update gate, [bS, nU]
// dLdc gradient wrt cell state, [bS, nU]
// dLdh gradient wrt current cell output, [bS, nU]
//Outputs:
// dLdx gradient wrt x, [bS, iS],
// dLdhLast gradient wrt hLast, [bS, nU]
// dLdW gradient wrt W, [iS+nU, 2*nU]
// dLdWc gradient wrt Wc, [iS+nU, nU]
// dLdb gradient wrt bru [2*nU]
// dLdbc gradient wrt bc [nU]
// * means element-wise product or so called Hadamard product
// × means matrix multiplication
/************************************************************************************************/
/******************************* THIS IS NOT OPTIMAZED CODE *************************************/
/*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/
const int bS = x->sizeAt(0);
const int iS = x->sizeAt(1);
const int nU = hLast->sizeAt(1);
NDArray xT = x->transpose(); // [iS, bS]
NDArray hLastT = hLast->transpose(); // [nU, bS]
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
NDArray br = (*b)({0, nU}); // [nU]
NDArray bu = (*b)({nU, 2*nU}); // [nU]
NDArray WrxT = Wrx.transpose(); // [nU, iS]
NDArray WuxT = Wux.transpose(); // [nU, iS]
NDArray WrhT = Wrh.transpose(); // [nU, nU]
NDArray WuhT = Wuh.transpose(); // [nU, nU]
NDArray WcxT = Wcx.transpose(); // [nU, iS]
NDArray WchT = Wch.transpose(); // [nU, nU]
NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU]
NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU]
NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU]
NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU]
NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU]
NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU]
NDArray dLdbr = (*dLdb)({0, nU}); // [nU]
NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU]
const int nU = h0->sizeAt(1);
// ***** feed forward step ***** // // ***** feed forward step ***** //
// gates = sigmoid(x*Wx + h0*Wh + b)
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nU})) + mmul(*h0, (*Wh)({0,0, 0,2*nU})) + (*b)({0,2*nU})); // [bS, 2*nU] + [bS, 2*nU] + [1, 2*nU] = [bS, 2*nU]
// reset gate // reset gate
auto r = gates({0,0, 0, nU}); // [bS, nU] NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
r.applyTransform(transform::Sigmoid);
// update gate // update gate
auto u = gates({0,0, nU, 2*nU}); // [bS, nU] NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
// ◦ means element-wise product or so called Hadamard product u.applyTransform(transform::Sigmoid);
// n = tanh(x*Wx + (r◦h0)*Wh + b)
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nU,3*nU})) + mmul((*h0)*r, (*Wh)({0,0, 2*nU,3*nU})) + (*b)({2*nU,3*nU})); // [bS, nU] // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc)
NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
c.applyTransform(transform::Tanh);
// h = (1 - u) * c + u * hPrev
// ***** back prop step ***** // // ***** back prop step ***** //
auto Wxr = (*Wx)({0,0, 0, nU});
auto Wxu = (*Wx)({0,0, nU, 2*nU});
auto Wxn = (*Wx)({0,0, 2*nU,3*nU});
auto Whr = (*Wh)({0,0, 0, nU});
auto Whu = (*Wh)({0,0, nU, 2*nU});
auto Whn = (*Wh)({0,0, 2*nU,3*nU});
auto WxrT = Wxr.transpose();
auto WxuT = Wxu.transpose();
auto WxnT = Wxn.transpose();
auto WhrT = Whr.transpose();
auto WhuT = Whu.transpose();
auto WhnT = Whn.transpose();
auto xT = x->transpose();
auto h0T = h0->transpose();
auto dLdWxr = (*dLdWx)({0,0, 0, nU}); // notations:
auto dLdWxu = (*dLdWx)({0,0, nU, 2*nU}); // Zr = x × Wrx + hLast × Wrh + br
auto dLdWxn = (*dLdWx)({0,0, 2*nU,3*nU}); // Zu = x × Wux + hLast × Wuh + bu
// Sr = sigmoid(Zr)
// Su = sigmoid(Zu)
// Zc = x × Wcx + (r * hlast) × Wch + bc
auto dLdWhr = (*dLdWh)({0,0, 0, nU});
auto dLdWhu = (*dLdWh)({0,0, nU, 2*nU});
auto dLdWhn = (*dLdWh)({0,0, 2*nU,3*nU});
auto dLdbr = (*dLdb)({0, nU}); // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
auto dLdbu = (*dLdb)({nU, 2*nU}); // = dLdx_u + dLdx_c
auto dLdbn = (*dLdb)({2*nU,3*nU}); // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
// dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
// dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
// dZcdr = (... * hLast) × WchT
// dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
// drdx = drdZr * dZrdx
// dZrdx = ... × WrxT
// dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
// finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
auto dhdu = *h0 - n; // [bS, nU]
auto dhdn = 1.f - u; // [bS, nU]
auto dSigdu = u * (1.f - u); // [bS, nU]
auto dSigdr = r * (1.f - r); // [bS, nU]
auto dActdn = 1.f - n * n; // [bS, nU]
auto dndr = mmul(dActdn * (*h0), WhnT);
auto drdh0 = mmul(dSigdr, WhrT);
auto dLdn = (*dLdh) * dhdn; // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
auto dLdu = (*dLdh) * dhdu; // = dLdhLast_h + dLdhLast_u + dLdhLast_c
auto dLdr = dLdn * dndr; // dLdhLast_h = dLdh * dhdhLas = dLdh * u
// dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
// dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
// = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
// = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
// dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
// dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
// finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
// = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nU]
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nU] // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nU,nU] // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
// dZrdWrx = xT × ...
// finally dLdWrx = xT × (dLdr * drdZr)
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nU]
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nU,nU]
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nU] // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nU,nU] // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
// dZrdWrh = hLastT × ...
// finally dLdWrh = hLastT × (dLdr * drdZr)
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nU]
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nU]
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nU]
if(dLdWx0 != nullptr) // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
*dLdWx += *dLdWx0; // dZudWux = xT × ...
// dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)
if(dLdWh0 != nullptr)
*dLdWh += *dLdWh0;
if(dLdb0 != nullptr) // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
*dLdb += *dLdb0; // dZudWuh = hLastT × ...
// finally dLdWuh = hLastT × (dLdu * dudZu)
// dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
// dZcdWcx = xT × ...
// finally dLdWcx = xT × (dLdc * dcdZc)
// dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
// dZcdWch = (r*hLast)^T × ...
// finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
// dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
// = dLdr * drdZr * dZrdbr
// dZrdbr = 1
// finally dLdbr = dLdr * drdZr
// dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
// dZudbu = 1
// finally dLdbu = dLdu * dudZu
// dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
// dZcdbc = 1
// finally dLdbc = dLdc * dcdZc
NDArray dhdc = 1.f - u; // [bS, nU]
NDArray dhdu = *hLast - c; // [bS, nU]
NDArray dudZu = u * dhdc; // [bS, nU]
NDArray drdZr = r * (1.f - r); // [bS, nU]
NDArray dcdZc = 1.f - c * c; // [bS, nU]
NDArray dLdZc = *dLdc * dcdZc; // [bS, nU]
NDArray dLdZu = *dLdu * dudZu; // [bS, nU]
NDArray dLdZr = *dLdr * drdZr; // [bS, nU]
// NDArray dLdc = *dLdh * dhdc; // [bS, nU]
// NDArray dLdu = *dLdh * dhdu; // [bS, nU]
// NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU]
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS]
dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU]
dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU]
dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU]
dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0})); // [nU]
dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0})); // [nU]
dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU]
} }

View File

@ -20,12 +20,111 @@
#include <ops/declarable/helpers/hashcode.h> #include <ops/declarable/helpers/hashcode.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T>
static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong length) {
for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) {
auto blockBuffer = buffer + b * numBlocks;
Nd4jLong r = 1;
for (int e = threadIdx.x; e < blockSize && e + (b * numBlocks) < length; e += blockDim.x) {
auto v = longBytes<T>(blockBuffer[e]);
r = 31 * r + v;
}
tempBuffer[b] = r;
}
}
template <typename T>
static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong lastLength) {
for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) {
auto blockBuffer = tempBuffer + b * numBlocks;
Nd4jLong r = 1;
for (int e = threadIdx.x; e < blockSize && e + (b * numBlocks) < lastLength; e += blockDim.x) {
auto v = longBytes<T>(blockBuffer[e]);
r = 31 * r + v;
}
tempResult[b] = r;
}
}
static __global__ void lastStep(Nd4jLong* resultBuf, Nd4jLong* tempBufferA, Nd4jLong* tempResult, Nd4jLong length, Nd4jLong blockSize) {
if (threadIdx.x == 0) {
if (length <= blockSize)
*resultBuf = *tempBufferA;
else
*resultBuf = *tempResult;
}
}
template <typename T>
void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) {
auto blockSize = 32;
auto stream = context->getCudaStream();
array.syncToDevice();
NDArray::prepareSpecialUse({&result}, {&array});
auto length = array.lengthOf();
int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1);
auto tempA = NDArrayFactory::create<Nd4jLong>('c', {numBlocks}, context);
auto tempB = NDArrayFactory::create<Nd4jLong>('c', { numBlocks / blockSize + 1}, context);
auto buffer = reinterpret_cast<T*>(array.specialBuffer()); //bufferAsT<T>();
auto tempBufferA = reinterpret_cast<Nd4jLong*>(tempA.specialBuffer()); //bufferAsT<Nd4jLong>();
auto tempBufferB = reinterpret_cast<Nd4jLong*>(tempB.specialBuffer()); //bufferAsT<Nd4jLong>();
// default buffer is the first one, because it might be the last one in case of small arrays (< blockSize)
auto tempBuffer = tempBufferA;
auto tempResult = tempBufferB;
// we divide array into 32 element chunks, and store intermediate results once
splitBufferToChuncks<T><<<numBlocks, length, 1024, *stream>>>(buffer, tempBuffer, numBlocks, blockSize, length);
// we replace pointer with intermediate one, and repeat only one chunk left
int iterationCount = 0;
while (numBlocks > 1) {
int lastLength = numBlocks;
numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1);
internalHash<Nd4jLong><<<numBlocks, lastLength, 1024, *stream>>>(tempBuffer, tempResult, numBlocks, blockSize, lastLength);
iterationCount++;
// swapping buffers
if (iterationCount % 2 == 0) {
tempBuffer = tempBufferA;
tempResult = tempBufferB;
} else {
tempBuffer = tempBufferB;
tempResult = tempBufferA;
}
}
//lastStep<Nd4jLong><<<1,1,128, *stream>>>(result.specialBuffer(), tempBufferA, tempResult, length, blockSize);
tempA.syncToHost();
tempB.syncToHost();
result.assign((length <= blockSize?tempA.e(0) : tempB.e(0)));
NDArray::registerSpecialUse({&result}, {&array});
}
void hashCode(LaunchContext *context, NDArray &array, NDArray &result) { void hashCode(LaunchContext *context, NDArray &array, NDArray &result) {
BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void hashCode_, (LaunchContext* context, NDArray& array, NDArray& result), LIBND4J_TYPES);
} }
} }
} }

View File

@ -20,6 +20,8 @@
#include <ops/declarable/helpers/image_suppression.h> #include <ops/declarable/helpers/image_suppression.h>
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
#include <NativeOps.h>
#include <cuda_exception.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -35,15 +37,16 @@ namespace helpers {
Nd4jLong next1[] = {nextIndex, 1}; Nd4jLong next1[] = {nextIndex, 1};
Nd4jLong next2[] = {nextIndex, 2}; Nd4jLong next2[] = {nextIndex, 2};
Nd4jLong next3[] = {nextIndex, 3}; Nd4jLong next3[] = {nextIndex, 3};
Nd4jLong* shapeOf = shape::shapeOf(boxesShape);
T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous2, 2)]); Nd4jLong* strideOf = shape::stride(boxesShape);
T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous3, 2)]); T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, previous0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous2, 2)]);
T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous2, 2)]); T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, previous1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous3, 2)]);
T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous3, 2)]); T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, previous0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous2, 2)]);
T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next2, 2)]); T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, previous1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous3, 2)]);
T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next3, 2)]); T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, next0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next2, 2)]);
T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next2, 2)]); T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, next1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next3, 2)]);
T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next3, 2)]); T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, next0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next2, 2)]);
T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, next1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next3, 2)]);
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev); T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext); T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
@ -62,149 +65,101 @@ namespace helpers {
}; };
template <typename T, typename I> template <typename T, typename I>
static __global__ void nonMaxSuppressionKernel(T* boxes, Nd4jLong* boxesShape, I* indices, int* selectedIndices, Nd4jLong numBoxes, I* output, Nd4jLong* outputShape, T threshold) { static __global__ void shouldSelectKernel(T* boxesBuf, Nd4jLong* boxesShape, I* indexBuf, I* selectedIndicesData, double threshold, int numSelected, int i, bool* shouldSelect) {
__shared__ Nd4jLong outputLen; auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto step = gridDim.x * blockDim.x;
__shared__ bool shouldSelectShared;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
outputLen = shape::length(outputShape); shouldSelectShared = shouldSelect[0];
} }
__syncthreads(); __syncthreads();
for (int j = numSelected - 1 - tid; j >= 0; j -= step) {
if (shouldSelectShared) {
if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i],
indexBuf[selectedIndicesData[j]], T(threshold)))
shouldSelectShared = false;
}
}
__syncthreads();
if (threadIdx.x == 0) {
*shouldSelect = shouldSelectShared;
}
}
template <typename I>
auto numSelected = blockIdx.x; static __global__ void copyIndices(void* indices, void* indicesLong, Nd4jLong len) {
auto start = blockIdx.x * blockDim.x + threadIdx.x; __shared__ I* indexBuf;
__shared__ Nd4jLong* srcBuf;
if (threadIdx.x == 0) {
indexBuf = reinterpret_cast<I*>(indices);
srcBuf = reinterpret_cast<Nd4jLong*>(indicesLong);
}
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x; auto step = blockDim.x * gridDim.x;
// for (int numSelected = blockIdx.x; numSelected < outputLen; numSelected += gridDim.x) {
for (int i = start; i < numBoxes; i += step) {
bool shouldSelect = true;
for (int j = numSelected - 1; shouldSelect && j >= 0; --j) {
if (needToSuppressWithThreshold<T>(boxes, boxesShape, indices[i], indices[selectedIndices[j]], threshold)) {
shouldSelect = false;
}
}
if (shouldSelect) { for (auto i = tid; i < len; i += step)
auto zPos = shape::getIndexOffset(numSelected, outputShape, outputLen); indexBuf[i] = (I)srcBuf[i];
output[zPos] = indices[i];
selectedIndices[numSelected] = i;
}
}
}
template <typename T, typename I>
static __global__ void sortIndices(I* indices, Nd4jLong* indexShape, T* scores, Nd4jLong* scoreShape) {
__shared__ Nd4jLong len;
// __shared__ Nd4jLong* sortedPart;
// __shared__ Nd4jLong part;
// __shared__ Nd4jLong partSize;
if (threadIdx.x == 0) {
// blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs; // ceil
// part = blockIdx.x / blocksPerArr;
len = shape::length(indexShape);
// __shared__ Nd4jLong* shmem = shared[];
// sortedPart = shmem;
}
for (int m = 0; m < len; m++) {
if (m % 2 == 0) {
for (int tid = threadIdx.x; tid < len; tid += blockDim.x) {
auto top = 2 * tid + 1;
if (top < len) {
auto t0 = shape::getIndexOffset(top - 1, indexShape, len);
auto t1 = shape::getIndexOffset(top, indexShape, len);
auto z0 = shape::getIndexOffset(top - 1, scoreShape, len);
auto z1 = shape::getIndexOffset(top, scoreShape, len);
if (scores[t0] < scores[t1]) {
// swap indices first
Nd4jLong di0 = indices[t0];
indices[t0] = indices[t1];
indices[t1] = di0;
//swap scores next
// T dz0 = scores[z0];
// scores[z0] = scores[z1];
// scores[z1] = dz0;
}
}
}
} else {
for (int tid = threadIdx.x; tid < len; tid += blockDim.x) {
auto top = 2 * tid + 2;
if (top < len) {
auto t0 = shape::getIndexOffset(top - 1, indexShape, len);
auto t1 = shape::getIndexOffset(top, indexShape, len);
auto z0 = shape::getIndexOffset(top - 1, scoreShape, len);
auto z1 = shape::getIndexOffset(top, scoreShape, len);
if (scores[t0] < scores[t1]) {
// swap indices first
Nd4jLong di0 = indices[t0];
indices[t0] = indices[t1];
indices[t1] = di0;
//swap scores next
// T dz0 = scores[z0];
// scores[z0] = scores[z1];
// scores[z1] = dz0;
}
}
}
}
__syncthreads();
}
} }
template <typename T, typename I> template <typename T, typename I>
static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {boxes, scales}); NDArray::prepareSpecialUse({output}, {boxes, scales});
NDArray* indices = NDArrayFactory::create_<I>('c', {scales->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext()); std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext());
indices->linspace(0); indices->linspace(0);
indices->syncToDevice(); // linspace only on CPU, so sync to Device as well
NDArray scores(*scales); NDArray scores(*scales);
indices->syncToHost(); //linspace(0); NativeOps nativeOps;
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
T* scoreBuf = reinterpret_cast<T*>(scores.specialBuffer()); Nd4jPointer extras[2] = {nullptr, stream};
sortIndices<T, I><<<1, 32, 128, *stream>>>(indexBuf, indices->specialShapeInfo(), scoreBuf, scores.specialShapeInfo());
nativeOps.sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
// TO DO: sort indices using scales as value row // TO DO: sort indices using scales as value row
//std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);}); //std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
indices->tickWriteDevice(); I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
indices->syncToHost();
indices->printIndexedBuffer("AFTERSORT OUTPUT");
NDArray selected = NDArrayFactory::create<int>({output->lengthOf()});
NDArray selectedIndices = NDArrayFactory::create<int>({output->lengthOf()}); NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()});
int numSelected = 0; int numSelected = 0;
int numBoxes = boxes->sizeAt(0); int numBoxes = boxes->sizeAt(0);
T* boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer()); T* boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer());
// Nd4jLong* indicesData = reinterpret_cast<Nd4jLong*>(indices->specialBuffer());
// int* selectedData = reinterpret_cast<int*>(selected.specialBuffer()); I* selectedIndicesData = reinterpret_cast<I*>(selectedIndices.specialBuffer());
int* selectedIndicesData = reinterpret_cast<int*>(selectedIndices.specialBuffer());
I* outputBuf = reinterpret_cast<I*>(output->specialBuffer()); I* outputBuf = reinterpret_cast<I*>(output->specialBuffer());
nonMaxSuppressionKernel<T, I><<<output->lengthOf(), 512, 1024, *stream>>>(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, numBoxes, outputBuf, output->specialShapeInfo(), T(threshold));
NDArray::registerSpecialUse({output}, {boxes, scales}); bool* shouldSelectD;
// for (int i = 0; i < boxes->sizeAt(0); ++i) { auto err = cudaMalloc(&shouldSelectD, sizeof(bool));
// if (selected.size() >= output->lengthOf()) break; if (err) {
// bool shouldSelect = true; throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot allocate memory for bool flag", err);
// // Overlapping boxes are likely to have similar scores, }
// // therefore we iterate through the selected boxes backwards. for (I i = 0; i < boxes->sizeAt(0); ++i) {
// for (int j = numSelected - 1; j >= 0; --j) { bool shouldSelect = numSelected < output->lengthOf();
// if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold)) { if (shouldSelect) {
// shouldSelect = false; err = cudaMemcpy(shouldSelectD, &shouldSelect, sizeof(bool), cudaMemcpyHostToDevice);
// break; if (err) {
// } throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", err);
// } }
// if (shouldSelect) {
// selected.push_back(indices[i]); shouldSelectKernel<T> <<< 128, 256, 1024, *stream >>>
// selectedIndices[numSelected++] = i; (boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD);
// } err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), cudaMemcpyDeviceToHost);
// } if (err) {
// for (size_t e = 0; e < selected.size(); ++e) throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", err);
// output->p<int>(e, selected[e]); }
// }
delete indices;
if (shouldSelect) {
cudaMemcpy(reinterpret_cast<I*>(output->specialBuffer()) + numSelected, indexBuf + i, sizeof(I), cudaMemcpyDeviceToDevice);
cudaMemcpy(selectedIndicesData + numSelected, &i, sizeof(I), cudaMemcpyHostToDevice);
numSelected++;
}
}
err = cudaFree(shouldSelectD);
if (err) {
throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot deallocate memory for bool flag", err);
}
} }
void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) { void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {

View File

@ -32,24 +32,24 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> // template <typename T>
static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { // static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
if (theFirst != theSecond) { // if (theFirst != theSecond) {
auto start = threadIdx.x + blockIdx.x * blockDim.x; // auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x; // auto step = blockDim.x * gridDim.x;
for (auto i = start; i < N; i += step) { // for (auto i = start; i < N; i += step) {
Nd4jLong iCoord1[] = {theFirst, i}; // Nd4jLong iCoord1[] = {theFirst, i};
Nd4jLong iCoord2[] = {theSecond, i}; // Nd4jLong iCoord2[] = {theSecond, i};
auto iIndex1 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord1, 2); // auto iIndex1 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord1, 2);
auto iIndex2 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord2, 2); // auto iIndex2 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord2, 2);
//atomicExch(&matrix[iIndex1], matrix[iIndex2]); // //atomicExch(&matrix[iIndex1], matrix[iIndex2]);
T e0 = matrix[iIndex1]; // T e0 = matrix[iIndex1];
T e1 = matrix[iIndex2]; // T e1 = matrix[iIndex2];
matrix[iIndex1] = e0; // matrix[iIndex1] = e0;
matrix[iIndex2] = e1; // matrix[iIndex2] = e1;
} // }
} // }
} // }
// BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES); // BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
// //
// void swapRows(NDArray* matrix, int theFirst, int theSecond) { // void swapRows(NDArray* matrix, int theFirst, int theSecond) {
@ -71,9 +71,14 @@ namespace helpers {
for (int i = start + 1; i < n; i += step) { for (int i = start + 1; i < n; i += step) {
Nd4jLong pos[] = {i, i - 1}; Nd4jLong pos[] = {i, i - 1};
Nd4jLong posX[] = {i, i};
Nd4jLong posY[] = {i - 1, i - 1};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
auto dxIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
auto dyIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
inverted[zIndex] = -input[xIndex]; inverted[zIndex] = -input[xIndex] / (input[dxIndex] * input[dyIndex]);
// math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex] / input[dIndex]);
} }
} }
@ -91,10 +96,11 @@ namespace helpers {
auto start = threadIdx.x + blockIdx.x * blockDim.x; auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x; auto step = blockDim.x * gridDim.x;
for (int i = start + 1; i < n; i += step) { for (int i = start; i < n; i += step) {
Nd4jLong pos[] = {i, i}; Nd4jLong pos[] = {i, i};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
// math::atomics::nd4j_atomicDiv(&inverted[zIndex], input[xIndex]);
inverted[zIndex] /= input[xIndex]; inverted[zIndex] /= input[xIndex];
} }
} }
@ -113,16 +119,16 @@ namespace helpers {
auto start = threadIdx.x + blockIdx.x * blockDim.x; auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x; auto step = blockDim.x * gridDim.x;
for (int i = start + 1; i < n - 1; i += step) { for (int i = start; i < n - 1; i += step) {
Nd4jLong pos[] = {i, i + 1}; Nd4jLong pos[] = {i, i + 1};
Nd4jLong posY[] = {i, i}; //Nd4jLong posY[] = {i, i};
Nd4jLong posX[] = {i + 1, i + 1}; Nd4jLong posX[] = {i + 1, i + 1};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); // auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2);
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2); // auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2); auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2); auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
inverted[zIndex] -= input[xIndex] * inverted[iIndex] / input[yIndex]; math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex]); // / input[yIndex]);
//inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i) //inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i)
} }
} }
@ -142,16 +148,18 @@ namespace helpers {
// auto step = blockDim.x * gridDim.x; // auto step = blockDim.x * gridDim.x;
for (int i = blockIdx.x + 2; i < n; i += gridDim.x) { for (int i = blockIdx.x + 2; i < n; i += gridDim.x) {
for (int j = i - 2; j > -1; --j) for (int j = i - 2; j >= 0; --j)
for (int k = threadIdx.x; k < i; k += blockDim.x) { for (int k = threadIdx.x; k < i; k += blockDim.x) {
Nd4jLong posZ[] = {i, j}; Nd4jLong posZ[] = {i, j};
Nd4jLong posX[] = {k, j}; Nd4jLong posY[] = {k, j};
Nd4jLong posY[] = {i, k}; Nd4jLong posX[] = {i, k};
Nd4jLong posD[] = {i, i};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
inverted[zIndex] -= inverted[yIndex] * input[xIndex]; math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex] / input[dIndex]);
} }
} }
} }
@ -176,13 +184,13 @@ namespace helpers {
Nd4jLong posZ[] = {i, j}; Nd4jLong posZ[] = {i, j};
Nd4jLong posY[] = {k, j}; Nd4jLong posY[] = {k, j};
Nd4jLong posX[] = {i, k}; Nd4jLong posX[] = {i, k};
Nd4jLong posD[] = {i, i}; // Nd4jLong posD[] = {i, i};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2); auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2); auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2); // auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2); auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
inverted[zIndex] -= inverted[yIndex] * input[xIndex] / input[dIndex]; math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex]);// / input[dIndex]);
} }
} }
} }
@ -196,14 +204,18 @@ namespace helpers {
LaunchContext* context = inputMatrix->getContext(); LaunchContext* context = inputMatrix->getContext();
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// invert main diagonal
upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invert the second diagonal
invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
} }
BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE);
void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
} }
template <typename T> template <typename T>
@ -215,58 +227,58 @@ namespace helpers {
return; return;
} }
upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); //upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n); invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
} }
BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE);
void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
} }
template <typename T> // template <typename T>
static __global__ void lupKernel(T* compound, Nd4jLong* compoundShape, T* permutation, Nd4jLong* permutationShape, Nd4jLong rowNum) { // static __global__ void lupKernel(T* compound, Nd4jLong* compoundShape, T* permutation, Nd4jLong* permutationShape, Nd4jLong rowNum) {
int swapCount = 0; // int swapCount = 0;
for(int i = blockIdx.x; i < rowNum; i += gridDim.x ) { // for(int i = blockIdx.x; i < rowNum; i += gridDim.x ) {
auto pivotValue = T(0.0); // auto pivotValue = T(0.0);
auto pivot = -1; // auto pivot = -1;
//
for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) { // for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) {
Nd4jLong rowCoord[] = {rowCounter, i}; // Nd4jLong rowCoord[] = {rowCounter, i};
auto rowPos = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), rowCoord, 2); // auto rowPos = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), rowCoord, 2);
if(nd4j::math::nd4j_abs(compound[rowPos]) > pivotValue ) { // if(nd4j::math::nd4j_abs(compound[rowPos]) > pivotValue ) {
pivotValue = nd4j::math::nd4j_abs(compound[rowPos]); // pivotValue = nd4j::math::nd4j_abs(compound[rowPos]);
pivot = rowCounter; // pivot = rowCounter;
} // }
} // }
//
if( pivotValue != T(0.0) ) { // if( pivotValue != T(0.0) ) {
swapRows_<T>(compound, compoundShape, pivot, i, rowNum); // swapRows_<T>(compound, compoundShape, pivot, i, rowNum);
swapRows_<T>(permutation, permutationShape, pivot, i, rowNum); // swapRows_<T>(permutation, permutationShape, pivot, i, rowNum);
if (pivot != i) // if (pivot != i)
swapCount++; // swapCount++;
//
for( int j = i + 1; j < rowNum; j++ ) { // for( int j = i + 1; j < rowNum; j++ ) {
Nd4jLong posJIbuf[] = {j, i}; // Nd4jLong posJIbuf[] = {j, i};
Nd4jLong posIIbuf[] = {i, i}; // Nd4jLong posIIbuf[] = {i, i};
auto posJI = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJIbuf, 2); // auto posJI = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJIbuf, 2);
auto posII = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIIbuf, 2); // auto posII = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIIbuf, 2);
//
compound[posJI] /= compound[posII]; // compound[posJI] /= compound[posII];
for( int k = i + 1; k < rowNum; k++ ) { // for( int k = i + 1; k < rowNum; k++ ) {
Nd4jLong posJKbuf[] = {j, k}; // Nd4jLong posJKbuf[] = {j, k};
Nd4jLong posIKbuf[] = {i, k}; // Nd4jLong posIKbuf[] = {i, k};
auto posJK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJKbuf, 2); // auto posJK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJKbuf, 2);
auto posIK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIKbuf, 2); // auto posIK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIKbuf, 2);
T arg = compound[posJI] * compound[posIK]; // T arg = compound[posJI] * compound[posIK];
compound[posJK] -= arg; // compound[posJK] -= arg;
} // }
} // }
} // }
} // }
} // }
template <typename T, typename F> template <typename T, typename F>
static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) { static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) {
@ -332,6 +344,30 @@ namespace helpers {
matrix[j] = (F)inputBuf[xIndex]; matrix[j] = (F)inputBuf[xIndex];
} }
} }
template <typename T, typename F>
static __global__ void returnMatrix(void* output, Nd4jLong* outputShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) {
__shared__ F* matrix;
__shared__ T* outputBuf;
__shared__ Nd4jLong outputLen;
__shared__ Nd4jLong n2;
if (threadIdx.x == 0) {
matrix = reinterpret_cast<F*>(input);
outputBuf = reinterpret_cast<T*>(output);
outputLen = shape::length(inputShape);
n2 = rowLen * rowLen;
}
__syncthreads();
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for (int k = pos + start, j = start; j < n2; k += step, j += step) {
auto zIndex = shape::getIndexOffset(k, outputShape, outputLen);
outputBuf[zIndex] = (T)matrix[j];
}
}
template <typename F> template <typename F>
static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) { static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) {
__shared__ F* permutation; __shared__ F* permutation;
@ -462,7 +498,7 @@ namespace helpers {
d_work, d_work,
permutationBuf, permutationBuf,
d_info); d_info);
fillUpPermutation<float><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n); fillUpPermutation<T><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
permutation->tickWriteDevice(); permutation->tickWriteDevice();
} }
err = cudaFree(d_work); err = cudaFree(d_work);
@ -483,7 +519,7 @@ namespace helpers {
// NDArray::registerSpecialUse({input}, {input}); // NDArray::registerSpecialUse({input}, {input});
input->tickWriteDevice(); input->tickWriteDevice();
} }
BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_NATIVE);
template <typename T> template <typename T>
static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) { static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
@ -504,32 +540,32 @@ namespace helpers {
output->assign(1.f); output->assign(1.f);
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2; Nd4jLong pos = e * n2;
if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
else // else
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
lup_<T>(context, &matrix, nullptr, nullptr); lup_<T>(context, &matrix, nullptr, nullptr);
else // else
lup_<float>(context, &matrix, nullptr, nullptr); // lup_<float>(context, &matrix, nullptr, nullptr);
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer()); auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset; auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
else // else
determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); // determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
} }
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
return Status::OK(); return Status::OK();
} }
BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
} }
template <typename T> template <typename T>
@ -552,22 +588,22 @@ namespace helpers {
output->assign(1.f); output->assign(1.f);
for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2; Nd4jLong pos = e * n2;
if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
else // else
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n); // fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
lup_<T>(context, &matrix, nullptr, nullptr); lup_<T>(context, &matrix, nullptr, nullptr);
else // else
lup_<float>(context, &matrix, nullptr, nullptr); // lup_<float>(context, &matrix, nullptr, nullptr);
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf()); auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer()); auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset; auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
if (matrix.dataType() == input->dataType()) // if (matrix.dataType() == input->dataType())
determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
else // else
determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n); // determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
} }
NDArray::registerSpecialUse({output}, {input}); NDArray::registerSpecialUse({output}, {input});
@ -576,10 +612,10 @@ namespace helpers {
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
} }
template <typename T> template <typename T>
@ -597,10 +633,12 @@ namespace helpers {
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
xShapeOf = shape::shapeOf(lowerShape); xShapeOf = shape::shapeOf(lowerShape);
yShapeOf = shape::shapeOf(upperShape);
zShapeOf = shape::shapeOf(matrixShape);
xStrideOf = shape::stride(lowerShape); xStrideOf = shape::stride(lowerShape);
yShapeOf = shape::shapeOf(upperShape);
yStrideOf = shape::stride(upperShape); yStrideOf = shape::stride(upperShape);
zShapeOf = shape::shapeOf(matrixShape);
zStrideOf = shape::stride(matrixShape); zStrideOf = shape::stride(matrixShape);
lowerMatrix = reinterpret_cast<T*>(lowerBuf); lowerMatrix = reinterpret_cast<T*>(lowerBuf);
upperMatrix = reinterpret_cast<T*>(upperBuf); upperMatrix = reinterpret_cast<T*>(upperBuf);
@ -610,15 +648,16 @@ namespace helpers {
for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it
for (int j = threadIdx.x; j < n; j += blockDim.x) { for (int j = threadIdx.x; j < n; j += blockDim.x) {
Nd4jLong posX[] = {j, k}; Nd4jLong posX[] = {k, j};
Nd4jLong posD[] = {j, j};
auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2); auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2);
auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2); auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2);
auto pos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2); auto iPos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2);
if (k <= j) auto dPos = shape::getOffset(0, zShapeOf, zStrideOf, posD, 2);
lowerMatrix[xPos] = matrix[pos];//(k, j); if (k >= j)
lowerMatrix[xPos] = matrix[iPos];//(k, j);
else else
upperMatrix[yPos] = matrix[pos]; //k, j); upperMatrix[yPos] = matrix[iPos]; //k, j);
} }
} }
} }
@ -639,38 +678,26 @@ namespace helpers {
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1}); auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1});
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
// PRAGMA_OMP_PARALLEL_FOR
for (auto i = 0LL; i < packX.numberOfTads(); i++) { for (auto i = 0LL; i < packX.numberOfTads(); i++) {
fillMatrix<T, float><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n); fillMatrix<T, T><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
permutation.assign(0.f);
lup_<float>(context, &matrix, &compound, &permutation);
matrix.tickWriteDevice(); matrix.tickWriteDevice();
permutation.tickWriteDevice(); compound.assign(matrix);
permutation.printIndexedBuffer("PERMUTE"); lup_<T>(context, &compound, nullptr, nullptr);
lower.setIdentity(); // set up U to identity matrix fillLowerUpperKernel<T><<<n, n, 128>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
upper.setIdentity(); matrix.assign(0);
fillLowerUpperKernel<float><<<1, n2, 128>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n); invertUpperMatrix(&upper, &matrix); // U^{-1}
lower.tickWriteDevice(); compound.assign(0);
upper.tickWriteDevice(); invertLowerMatrix(&lower, &compound); // L{-1}
invertUpperMatrix(&upper, &matrix);
invertLowerMatrix(&lower, &upper);
lower.tickWriteDevice();
upper.tickWriteDevice();
lower.printIndexedBuffer("LOWER");
upper.printIndexedBuffer("UPPER");
nd4j::MmulHelper::mmul(&matrix, &upper, &compound, 1.0, 0.0); nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0); returnMatrix<T, T><<<1, n2, 128, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
// for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
// output->t<T>(k) = matrix.template t<T>(row++);
// }
} }
return Status::OK(); return Status::OK();
} }
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
} }
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) { bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
@ -803,7 +830,7 @@ namespace helpers {
return cholesky_(context, input, output, inplace); return cholesky_(context, input, output, inplace);
} }
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); // BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);
__global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) { __global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) {
__shared__ double* output; __shared__ double* output;

View File

@ -143,7 +143,7 @@ namespace helpers {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
static void _reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){ static void reverseSequence_(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){
int posOfNonUnityDim = -1; int posOfNonUnityDim = -1;
seqLengths->syncToHost(); seqLengths->syncToHost();
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
@ -193,7 +193,7 @@ namespace helpers {
} }
void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) { void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) {
BUILD_SINGLE_SELECTOR(input->dataType(), _reverseSequence, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -391,7 +391,7 @@ static void scatterCudaLauncher(const int blocksPerGrid, const int threadsPerBlo
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
PointersManager manager(context, "scatterND"); PointersManager manager(context, "scatter");
NDArray::prepareSpecialUse({&output}, {&updates, &indices}); NDArray::prepareSpecialUse({&output}, {&updates, &indices});

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,427 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/segment.h>
#include <ops/declarable/helpers/segment_common.h>
#include <NDArrayFactory.h>
#include <helpers/ShapeUtils.h>
#include <helpers/TAD.h>
#include <exceptions/cuda_exception.h>
#include <PointersManager.h>
#include <ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
// Segment ops linear kernels
// -------------------------------------------------------------------------------------------------------------- //
template<typename T, typename I>
static __global__ void
segmentMaxLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
void *output, Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
extern __shared__ unsigned char shmem[];
val = reinterpret_cast<T *>(shmem);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
start = starts[segment];
finish = start + lengths[segment];
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
val[segment] = z[zIndex];
}
}
__syncthreads();
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
}
}
// -------------------------------------------------------------------------------------------------------------- //
template<typename T, typename I>
static __global__ void
unsortedSegmentMaxLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__
I *y; //int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
segment = blockIdx.x;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
y = reinterpret_cast<I *>(indices);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
//start = starts[segment];
//finish = start + lengths[segment];
if (lengths[segment] > 0)
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
else
z[zIndex] = -DataTypeUtils::max<T>();
}
__syncthreads();
if (lengths[segment] > 0)
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
if (y[yIndex] == segment) {
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentMaxTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads,
Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf,
Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets, T filler = 0) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ T* z;
__shared__ int start, finish;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
finish = start + lengths[segment];
total = shape::sizeAt(inputShape, 0);
}
__syncthreads();
auto idx = blockIdx.x;
if (blockIdx.x <= total) {
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
if (blockIdx.x == start) {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
}
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static void segmentMaxFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
//int numClasses = output->sizeAt(0);
// if input is a vector: (as if in doc sample)
//Nd4jLong idx = indices->e<Nd4jLong>(0);
auto stream = context->getCudaStream();
indices->syncToHost();
Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(256, 512, 256);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
if (input->isVector()) {
segmentMaxLinearKernel<T,I><<<numOfClasses, input->lengthOf(), numOfClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
segmentMaxTadKernel<T,I><<<packX.numberOfTads(), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
}
// -------------------------------------------------------------------------------------------------------------- //
void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template void segmentMaxFunctor_, (LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static void unsortedSegmentMaxFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.getSpecialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.getSpecialBuffer());
if (input->isVector()) {
unsortedSegmentMaxLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
dims.x = input->sizeAt(0);
output->assign(-DataTypeUtils::max<T>());
segmentMaxTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
}
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
// segment max
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentMaxBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
void* outputBuf, Nd4jLong* outputShape) {
__shared__ T* x;
__shared__ T* gradIn;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, gradLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
gradIn = reinterpret_cast<T*>(forwardOutput);
gradOut = reinterpret_cast<T*>(eps);
gradLen = shape::length(epsShape);
}
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = gridDim.x * blockDim.x;
for (auto e = start; e < xLen; e += step) {
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
auto classIndex = y[yOffset];
auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen);
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
if (nd4j::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) {
z[zOffset] = gradOut[gradOffsetO];
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentMaxBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets,
Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad,
Nd4jLong* outOffsets) {
__shared__ T* x;
__shared__ T* gradIn;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
yLen = shape::length(indicesShape);
gradOut = reinterpret_cast<T*>(eps);
gradIn = reinterpret_cast<T*>(forwardOutput);
gradLen = shape::length(epsShape);
currentLen = shape::length(outTad);
}
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
auto segment = y[yIndex];
T* current = x + inputOffsets[i];
T* currentOut = z + outOffsets[i];
T* in = gradIn + gradInOffsets[segment];
T* outGrad = gradOut + gradOutOffsets[segment];
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
if (nd4j::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6))
currentOut[e] = outGrad[e];
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
int segmentMaxFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
//int numOfClasses = gradOut->sizeAt(0);
// if input is a vector: (as if in doc sample)
auto stream = context->getCudaStream();
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
segmentMaxFunctor_<T, I>(context, input, indices, &tempRes);
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
if (input->isVector()) {
Nd4jLong loop_size = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentMaxBPLinearKernel<T,I><<<1 + gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentMaxBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input,
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template int segmentMaxFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static int unsortedSegmentMaxFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
//int numOfClasses = gradOut->sizeAt(0);
// if input is a vector: (as if in doc sample)
auto stream = context->getCudaStream();
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
unsortedSegmentMaxFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
if (input->isVector()) {
Nd4jLong loop_size = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentMaxBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentMaxBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
}
}
}

View File

@ -0,0 +1,414 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/segment.h>
#include <ops/declarable/helpers/segment_common.h>
#include <NDArrayFactory.h>
#include <helpers/ShapeUtils.h>
#include <helpers/TAD.h>
#include <exceptions/cuda_exception.h>
#include <PointersManager.h>
#include <ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
// Segment ops linear kernels
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentMeanLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
__shared__ T* val;
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
__shared__ T* x;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T*>(input);
z = reinterpret_cast<T*>(output);
// extern __shared__ unsigned char shmem[];
// val = reinterpret_cast<T*>(shmem);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
//[zIndex] =
if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
start = starts[segment];
finish = start + lengths[segment];
//val[segment] = ;
z[zIndex] = T(x[shape::getIndexOffset(start, inputShape, xLen)] / lengths[segment]);
// val[segment] = z[zIndex];
}
}
__syncthreads();
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
if (lengths[segment])
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex] / lengths[segment]));
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void unsortedSegmentMeanLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
__shared__ T* val;
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
__shared__ T* x;
__shared__ T* z;
__shared__ I* y; //int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x;// / threadsPerSegment;
x = reinterpret_cast<T*>(input);
z = reinterpret_cast<T*>(output);
y = reinterpret_cast<I*>(indices);
// extern __shared__ unsigned char shmem[];
// val = reinterpret_cast<T*>(shmem);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
// if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
//start = starts[segment];
//finish = start + lengths[segment];
if (lengths[segment] > 0)
z[zIndex] = T(x[shape::getIndexOffset(starts[segment], inputShape, xLen)] / T(lengths[segment]));
else
z[zIndex] = 0; //DataTypeUtils::max<T>();
// val[segment] = z[zIndex];
// }
}
__syncthreads();
if (lengths[segment] > 0)
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
if (y[yIndex] == segment && e != starts[segment]) {
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/T(lengths[segment])));
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
// SegmentMean kernel
template <typename T, typename I>
static __global__ void segmentMeanTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
finish = start + lengths[segment];
total = shape::sizeAt(inputShape, 0);
}
__syncthreads();
auto idx = blockIdx.x;
if (blockIdx.x <= total) {
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
if (blockIdx.x == start) {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = T(x[xIndex]/lengths[segment]);
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
if (lengths[segment])
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/lengths[segment]));
}
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
// segmen mean
template <typename T, typename I>
static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
if (input->isVector()) {
segmentMeanLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
segmentMeanTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
}
// -------------------------------------------------------------------------------------------------------------- //
void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), FLOAT_TYPES, INTEGER_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template void segmentMeanFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
unsortedSegmentMeanLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
output->assign(0);
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
dims.x = input->sizeAt(0);
segmentMeanTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
}
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output),
FLOAT_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMeanFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentMeanBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
int* lengths, void* outputBuf, Nd4jLong* outputShape) {
__shared__ T* x;
__shared__ T* gradIn;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, gradLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
gradOut = reinterpret_cast<T*>(eps);
gradLen = shape::length(epsShape);
}
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = gridDim.x * blockDim.x;
for (auto e = start; e < xLen; e += step) {
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
auto classIndex = y[yOffset];
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
z[zOffset] = T(gradOut[gradOffsetO] / float(lengths[classIndex]));
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentMeanBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
void* indicesBuf, Nd4jLong* indicesShape, int* lengths, void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) {
__shared__ T* x;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
yLen = shape::length(indicesShape);
gradOut = reinterpret_cast<T*>(eps);
gradLen = shape::length(epsShape);
currentLen = shape::length(outTad);
}
__syncthreads();
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
// auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
auto segment = y[i]; //yIndex];
T* currentOut = z + outOffsets[i];
T* outGrad = gradOut + gradOutOffsets[segment];
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
auto zIndex = shape::getIndexOffset(e, outTad, currentLen);
auto gradIndex = shape::getIndexOffset(e, gradOutTad, gradLen);
if (lengths[segment] > 0)
currentOut[zIndex] = T(outGrad[gradIndex] / float(lengths[segment]));
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
// backrop for mean
template <typename T, typename I>
int segmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
Nd4jLong loop_size = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentMeanBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
// auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentMeanBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths,
output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
// segmen mean bp main
int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input,
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template int segmentMeanFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static int unsortedSegmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
Nd4jLong loop_size = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentMeanBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
// auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentMeanBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths,
output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMeanFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
}
}
}

View File

@ -0,0 +1,423 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/segment.h>
#include <ops/declarable/helpers/segment_common.h>
#include <NDArrayFactory.h>
#include <helpers/ShapeUtils.h>
#include <helpers/TAD.h>
#include <exceptions/cuda_exception.h>
#include <PointersManager.h>
#include <ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
// Segment ops linear kernels
// -------------------------------------------------------------------------------------------------------------- //
template<typename T, typename I>
static __global__ void
segmentMinLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
void *output, Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
extern __shared__ unsigned char shmem[];
val = reinterpret_cast<T *>(shmem);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
start = starts[segment];
finish = start + lengths[segment];
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
val[segment] = z[zIndex];
}
}
__syncthreads();
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
}
}
// -------------------------------------------------------------------------------------------------------------- //
template<typename T, typename I>
static __global__ void
unsortedSegmentMinLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__
I *y; //int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
segment = blockIdx.x;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
y = reinterpret_cast<I *>(indices);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
if (lengths[segment] > 0)
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
else
z[zIndex] = DataTypeUtils::max<T>();
}
__syncthreads();
if (lengths[segment] > 0)
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
if (y[yIndex] == segment) {
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
// SegmentMin kernel
template <typename T, typename I>
static __global__ void segmentMinTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
finish = start + lengths[segment];
total = shape::sizeAt(inputShape, 0);
}
__syncthreads();
auto idx = blockIdx.x;
if (blockIdx.x <= total) {
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
if (blockIdx.x == start) {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
}
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
// segmen min
template <typename T, typename I>
static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
segmentMinLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
segmentMinTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
}
// -------------------------------------------------------------------------------------------------------------- //
void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template void segmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static void unsortedSegmentMinFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
NDArray::prepareSpecialUse({output}, {input, indices});
if (input->isVector()) {
unsortedSegmentMinLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
output->assign(DataTypeUtils::max<T>());
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
dims.x = input->sizeAt(0);
segmentMinTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices});
}
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
template <typename T, typename I>
static __global__ void segmentMinBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
void* outputBuf, Nd4jLong* outputShape) {
__shared__ T* x;
__shared__ T* gradIn;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, gradLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
gradIn = reinterpret_cast<T*>(forwardOutput);
gradOut = reinterpret_cast<T*>(eps);
gradLen = shape::length(epsShape);
}
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = gridDim.x * blockDim.x;
for (auto e = start; e < xLen; e += step) {
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
auto classIndex = y[yOffset];
auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen);
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
if (nd4j::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) {
z[zOffset] = gradOut[gradOffsetO];
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentMinBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets,
Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad,
Nd4jLong* outOffsets) {
__shared__ T* x;
__shared__ T* gradIn;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
yLen = shape::length(indicesShape);
gradOut = reinterpret_cast<T*>(eps);
gradIn = reinterpret_cast<T*>(forwardOutput);
gradLen = shape::length(epsShape);
currentLen = shape::length(outTad);
}
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
auto segment = y[yIndex];
T* current = x + inputOffsets[i];
T* currentOut = z + outOffsets[i];
T* in = gradIn + gradInOffsets[segment];
T* outGrad = gradOut + gradOutOffsets[segment];
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
if (nd4j::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6))
currentOut[e] = outGrad[e];
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
int segmentMinFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
//int numOfClasses = gradOut->sizeAt(0);
// if input is a vector: (as if in doc sample)
auto stream = context->getCudaStream();
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
segmentMinFunctor_<T, I>(context, input, indices, &tempRes);
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
if (input->isVector()) {
Nd4jLong loop_size = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentMinBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentMinBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
// segmen min
int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input,
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template int segmentMinFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static int unsortedSegmentMinFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
//int numOfClasses = gradOut->sizeAt(0);
// if input is a vector: (as if in doc sample)
auto stream = context->getCudaStream();
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
unsortedSegmentMinFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
if (input->isVector()) {
Nd4jLong loop_size = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentMinBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentMinBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
}
}
}

View File

@ -0,0 +1,419 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/segment.h>
#include <ops/declarable/helpers/segment_common.h>
#include <NDArrayFactory.h>
#include <helpers/ShapeUtils.h>
#include <helpers/TAD.h>
#include <exceptions/cuda_exception.h>
#include <PointersManager.h>
#include <ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
// Segment Prod ops linear kernels
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentProdLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
__shared__ T* val;
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
__shared__ T* x;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T*>(input);
z = reinterpret_cast<T*>(output);
extern __shared__ unsigned char shmem[];
val = reinterpret_cast<T*>(shmem);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
start = starts[segment];
finish = start + lengths[segment];
//val[segment] = ;
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
val[segment] = z[zIndex];
}
}
__syncthreads();
// auto tid = threadIdx.x + blockIdx.x * blockDim.x;
// auto step = blockDim.x * gridDim.x;
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
nd4j::math::atomics::nd4j_atomicMul(&val[segment], x[xIndex]);
}
__syncthreads();
if (threadIdx.x == 0) {
z[zIndex] = val[segment];
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void unsortedSegmentProdLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
__shared__ T* val;
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
__shared__ T* x;
__shared__ T* z;
__shared__ I* y; //int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x;// / threadsPerSegment;
x = reinterpret_cast<T*>(input);
z = reinterpret_cast<T*>(output);
y = reinterpret_cast<I*>(indices);
// extern __shared__ unsigned char shmem[];
// val = reinterpret_cast<T*>(shmem);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
// if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
//start = starts[segment];
//finish = start + lengths[segment];
if (lengths[segment] > 0)
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
else
z[zIndex] = 0; //DataTypeUtils::max<T>();
// val[segment] = z[zIndex];
// }
}
__syncthreads();
if (lengths[segment] > 0)
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
if (y[yIndex] == segment && e != starts[segment]) {
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
// SegmentProd kernel
template <typename T, typename I>
static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
finish = start + lengths[segment];
total = shape::sizeAt(inputShape, 0);
}
__syncthreads();
auto idx = blockIdx.x;
if (blockIdx.x <= total) {
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
if (blockIdx.x == start) {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
}
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static void segmentProdFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
segmentProdLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
segmentProdTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
}
// -------------------------------------------------------------------------------------------------------------- //
void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template void segmentProdFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static void unsortedSegmentProdFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
unsortedSegmentProdLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
output->assign(1);
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
dims.x = input->sizeAt(0);
segmentProdTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
}
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output),
FLOAT_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentProdBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
void* outputBuf, Nd4jLong* outputShape) {
__shared__ T* x;
__shared__ T* gradIn;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, gradLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
gradIn = reinterpret_cast<T*>(forwardOutput);
gradOut = reinterpret_cast<T*>(eps);
gradLen = shape::length(epsShape);
}
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = gridDim.x * blockDim.x;
for (auto e = start; e < xLen; e += step) {
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
auto classIndex = y[yOffset];
auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen);
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
z[zOffset] = gradOut[gradOffsetO] * gradIn[gradOffsetI] / x[xOffset];
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentProdBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets,
Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad,
Nd4jLong* outOffsets) {
__shared__ T* x;
__shared__ T* gradIn;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
yLen = shape::length(indicesShape);
gradOut = reinterpret_cast<T*>(eps);
gradIn = reinterpret_cast<T*>(forwardOutput);
gradLen = shape::length(epsShape);
currentLen = shape::length(outTad);
}
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
auto segment = y[yIndex];
T* current = x + inputOffsets[i];
T* currentOut = z + outOffsets[i];
T* in = gradIn + gradInOffsets[segment];
T* outGrad = gradOut + gradOutOffsets[segment];
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
currentOut[e] = outGrad[e] * in[e] / current[e];
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
int segmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
auto stream = context->getCudaStream();
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
segmentProdFunctor_<T, I>(context, input, indices, &tempRes);
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
if (input->isVector()) {
Nd4jLong loopSize = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentProdBPLinearKernel<T,I><<<gradOut->lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentProdBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input,
indices, gradOut, output), FLOAT_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template int segmentProdFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static int unsortedSegmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
unsortedSegmentProdFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
if (input->isVector()) {
Nd4jLong loopSize = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentProdBPLinearKernel<T,I><<<gradOut->lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentProdBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentProdFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
}
}
}

View File

@ -0,0 +1,280 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/segment.h>
#include <ops/declarable/helpers/segment_common.h>
#include <NDArrayFactory.h>
#include <helpers/ShapeUtils.h>
#include <helpers/TAD.h>
#include <exceptions/cuda_exception.h>
#include <PointersManager.h>
#include <ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void unsortedSegmentSqrtNLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
__shared__ T* val;
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
__shared__ T* x;
__shared__ T* z;
__shared__ I* y; //int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x;// / threadsPerSegment;
x = reinterpret_cast<T*>(input);
z = reinterpret_cast<T*>(output);
y = reinterpret_cast<I*>(indices);
// extern __shared__ unsigned char shmem[];
// val = reinterpret_cast<T*>(shmem);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
// if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
//start = starts[segment];
//finish = start + lengths[segment];
if (lengths[segment] > 0)
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]);
else
z[zIndex] = 0; //DataTypeUtils::max<T>();
// val[segment] = z[zIndex];
// }
}
__syncthreads();
if (lengths[segment] > 0)
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
if (y[yIndex] == segment && e != starts[segment]) {
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
// SegmentSqrtN kernel
template <typename T, typename I>
static __global__ void segmentSqrtNTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
finish = start + lengths[segment];
total = shape::sizeAt(inputShape, 0);
}
__syncthreads();
auto idx = blockIdx.x;
if (blockIdx.x <= total) {
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
if (blockIdx.x == start) {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]);
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
}
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static void unsortedSegmentSqrtNFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
unsortedSegmentSqrtNLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
output->assign(0);
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
dims.x = input->sizeAt(0);
segmentSqrtNTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
}
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output),
FLOAT_TYPES, INTEGER_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSqrtNFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentSqrtNBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
int* lengths, void* outputBuf, Nd4jLong* outputShape) {
__shared__ T* x;
__shared__ T* gradIn;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, gradLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
gradOut = reinterpret_cast<T*>(eps);
gradLen = shape::length(epsShape);
}
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = gridDim.x * blockDim.x;
for (auto e = start; e < xLen; e += step) {
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
auto classIndex = y[yOffset];
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
z[zOffset] = T(gradOut[gradOffsetO] / math::nd4j_sqrt<int, float>(lengths[classIndex]));
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentSqrtNBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
void* indicesBuf, Nd4jLong* indicesShape, int* lengths, void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) {
__shared__ T* x;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
yLen = shape::length(indicesShape);
gradOut = reinterpret_cast<T*>(eps);
gradLen = shape::length(epsShape);
currentLen = shape::length(outTad);
}
__syncthreads();
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
// auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
auto segment = y[i]; //yIndex];
T* currentOut = z + outOffsets[i];
T* outGrad = gradOut + gradOutOffsets[segment];
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
auto zIndex = shape::getIndexOffset(e, outTad, currentLen);
auto gradIndex = shape::getIndexOffset(e, gradOutTad, gradLen);
if (lengths[segment] > 0)
currentOut[zIndex] = T(outGrad[gradIndex] / math::nd4j_sqrt<int, float>(lengths[segment]));
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static int unsortedSegmentSqrtNFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
Nd4jLong loop_size = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentSqrtNBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
// auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentSqrtNBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths,
output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSqrtNFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
}
}
}

View File

@ -0,0 +1,393 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author GS <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/segment.h>
#include <ops/declarable/helpers/segment_common.h>
#include <NDArrayFactory.h>
#include <helpers/ShapeUtils.h>
#include <helpers/TAD.h>
#include <exceptions/cuda_exception.h>
#include <PointersManager.h>
#include <ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
// -------------------------------------------------------------------------------------------------------------- //
// Segment ops linear kernels
// -------------------------------------------------------------------------------------------------------------- //
template<typename T, typename I>
static __global__ void
segmentSumLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
void *output, Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
segment = blockIdx.x / threadsPerSegment;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
if (segment < numOfClasses) {
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
start = starts[segment];
finish = start + lengths[segment];
//val[segment] = ;
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
}
}
__syncthreads();
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
}
}
// -------------------------------------------------------------------------------------------------------------- //
template<typename T, typename I>
static __global__ void
unsortedSegmentSumLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
Nd4jLong *outputShape) {
__shared__
T *val;
__shared__
Nd4jLong xLen, zLen, segment, zIndex;
__shared__
T *x;
__shared__
T *z;
__shared__
I *y; //int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
segment = blockIdx.x;
x = reinterpret_cast<T *>(input);
z = reinterpret_cast<T *>(output);
y = reinterpret_cast<I *>(indices);
xLen = shape::length(inputShape);
zLen = shape::length(outputShape);
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
if (lengths[segment] > 0)
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
else
z[zIndex] = 0; //DataTypeUtils::max<T>();
}
__syncthreads();
if (lengths[segment] > 0)
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
if (y[yIndex] == segment && e != starts[segment]) {
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
// SegmentSum kernel
template <typename T, typename I>
static __global__ void segmentSumTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
__shared__ T* val;
__shared__ Nd4jLong len, segment, zIndex, total;
__shared__ T* z;
__shared__ int threadsPerSegment, start, finish;
if (threadIdx.x == 0) {
segment = indices[blockIdx.x]; // / threadsPerSegment;
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
len = shape::length(inputTads);
start = starts[segment];
finish = start + lengths[segment];
total = shape::sizeAt(inputShape, 0);
}
__syncthreads();
auto idx = blockIdx.x;
if (blockIdx.x <= total) {
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
if (blockIdx.x == start) {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
z[zIndex] = x[xIndex];
}
}
else {
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
auto xIndex = shape::getIndexOffset(e, inputTads, len);
auto zIndex = shape::getIndexOffset(e, outputTads, len);
if (lengths[segment])
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
}
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static void segmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
auto stream = context->getCudaStream();
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
segmentSumLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
segmentSumTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
}
// -------------------------------------------------------------------------------------------------------------- //
void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template void segmentSumFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static void unsortedSegmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
classesRangesBegs.assign(indices->lengthOf());
classesRangesLens.assign(0);
dim3 dims(numOfClasses, indices->lengthOf(), (numOfClasses + 1) * 64);
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
if (input->isVector()) {
unsortedSegmentSumLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
}
else {
output->assign(0);
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
dims.x = input->sizeAt(0);
segmentSumTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
}
}
// -------------------------------------------------------------------------------------------------------------- //
void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output),
NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSumFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
// Backpropagate ops
// -------------------------------------------------------------------------------------------------------------- //
// Sorted sum backpropagate
template <typename T, typename I>
static __global__ void segmentSumBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
void* indicesBuf, Nd4jLong* indicesShape, void* outputBuf, Nd4jLong* outputShape) {
__shared__ T* x;
__shared__ T* gradIn;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, gradLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
gradOut = reinterpret_cast<T*>(eps);
gradLen = shape::length(epsShape);
}
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = gridDim.x * blockDim.x;
for (auto e = start; e < xLen; e += step) {
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
auto classIndex = y[yOffset];
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
z[zOffset] = gradOut[gradOffsetO];
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static __global__ void segmentSumBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
void* indicesBuf, Nd4jLong* indicesShape, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* inputTad,
Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) {
__shared__ T* x;
__shared__ T* gradOut;
__shared__ I* y;
__shared__ T* z;
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
if (threadIdx.x == 0) {
xLen = shape::length(inputShape);
x = reinterpret_cast<T*>(inputBuf);
y = reinterpret_cast<I*>(indicesBuf);
z = reinterpret_cast<T*>(outputBuf);
yLen = shape::length(indicesShape);
gradOut = reinterpret_cast<T*>(eps);
gradLen = shape::length(epsShape);
currentLen = shape::length(outTad);
}
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
auto segment = y[yIndex];
T* currentOut = z + outOffsets[i];
T* outGrad = gradOut + gradOutOffsets[segment];
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
currentOut[e] = outGrad[e];
}
}
}
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
int segmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
if (input->isVector()) {
Nd4jLong loop_size = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentSumBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentSumBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input,
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template int segmentSumFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
// -------------------------------------------------------------------------------------------------------------- //
template <typename T, typename I>
static int unsortedSegmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
if (input->isVector()) {
Nd4jLong loop_size = input->lengthOf();
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
segmentSumBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
}
else {
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
Nd4jLong* inputTads = packX.specialShapeInfo();
Nd4jLong* inputTadOffsets = packX.specialOffsets();
Nd4jLong* outputTads = packZ.specialShapeInfo();
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
segmentSumBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
gradOut->specialBuffer(), gradOut->specialShapeInfo(),
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
outputTads, outputTadOffsets);
}
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
return Status::OK();
}
// -------------------------------------------------------------------------------------------------------------- //
int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSumFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
}
}
}

View File

@ -24,16 +24,40 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename I, typename B>
static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) { static __global__ void sequenceMaskKernel(void* inputBuf, Nd4jLong* inputShape, void* outputBuf, Nd4jLong* outputShape, int maxIndex) {
//
__shared__ I* input;
__shared__ B* output;
__shared__ Nd4jLong inputLen, outputLen;
if (threadIdx.x == 0) {
input = reinterpret_cast<I*>(inputBuf);
output = reinterpret_cast<B*>(outputBuf);
inputLen = shape::length(inputShape);
outputLen = shape::length(outputShape);
}
for (auto i = blockIdx.x; i < maxIndex; i += gridDim.x)
for(auto k = threadIdx.x; k < inputLen; k += blockDim.x)
if (i < input[shape::getIndexOffset(k, inputShape, inputLen)])
output[shape::getIndexOffset(k * maxIndex + i, outputShape, outputLen)] = B(true);
}
template <typename I, typename B>
static void sequenceMask_(LaunchContext* context, NDArray* input, NDArray* output, int maxIndex) {
dim3 launchDims(maxIndex, input->lengthOf(), 128);
NDArray::prepareSpecialUse({output}, {input});
auto stream = context->getCudaStream();
sequenceMaskKernel<I, B><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), maxIndex);
NDArray::registerSpecialUse({output}, {input});
} }
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
BUILD_SINGLE_SELECTOR(input->dataType(), sequenceMask_, (input, output, maxIndex), LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), LIBND4J_TYPES); BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
} }
} }
} }

View File

@ -456,27 +456,246 @@ void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArr
manager.synchronize(); manager.synchronize();
} }
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void scatterUpdateCuda(const int opCode, const int numOfInd,
void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets,
void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets,
const int* indexes) {
__shared__ T *x, *y;
__shared__ Nd4jLong arrLenX, arrLenY;
for (int e = 0; e < numOfInd; e++ ) {
const auto xIndex = indexes[e];
const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x;
if (!isOwner)
continue;
if (threadIdx.x == 0) {
x = reinterpret_cast<T*>(vx) + xOffsets[xIndex];
y = reinterpret_cast<T*>(vy) + yOffsets[e];
arrLenX = shape::length(xShapeInfo);
arrLenY = shape::length(yShapeInfo);
}
__syncthreads();
if (arrLenX != arrLenY)
return;
for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) {
const auto xOffset = shape::getIndexOffset(i, xShapeInfo, arrLenX);
const auto yOffset = shape::getIndexOffset(i, yShapeInfo, arrLenY);
switch (opCode) {
case 0:
x[xOffset] += y[yOffset];
break;
case 1:
x[xOffset] -= y[yOffset];
break;
case 2:
x[xOffset] *= y[yOffset];
break;
case 3:
x[xOffset] /= y[yOffset];
break;
case 4:
x[xOffset] = y[yOffset] - x[xOffset];
break;
case 5:
x[xOffset] = y[yOffset] / x[xOffset];
break;
case 6:
x[xOffset] = y[yOffset];
break;
default:
continue;
}
}
__syncthreads();
}
}
template<typename T>
__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const int* indexes) {
scatterUpdateCuda<T><<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfInd, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes);
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void scatterUpdate(nd4j::LaunchContext* context, NDArray& input, NDArray& updates, const std::vector<int>* intArgs) {
const int opCode = (*intArgs)[0];
const int numOfDims = (*intArgs)[1];
const int numOfInd = (*intArgs)[2 + numOfDims];
std::vector<int> tadDimensions(numOfDims);
for (int e = 2; e < 2 + numOfDims; e++)
tadDimensions[e-2] = (*intArgs)[e];
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), tadDimensions);
auto packY = ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), tadDimensions);
NDArray indices(const_cast<int*>(intArgs->data()) + numOfDims + 3, 'c', {numOfInd}, nd4j::DataType::INT32, context);
PointersManager manager(context, "scatterUpdate");
NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices});
BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, (context->getCudaStream(), opCode, numOfInd, input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), updates.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), reinterpret_cast<int*>(indices.getSpecialBuffer())), LIBND4J_TYPES);
NDArray::registerSpecialUse({&input}, {&input, &updates, &indices});
manager.synchronize();
}
template <typename T> template <typename T>
void randomShuffle_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) { static __global__ void swapShuffleKernel(T* input, Nd4jLong* shape, Nd4jLong firstDim, Nd4jLong len, nd4j::graph::RandomGenerator* rng) {
auto tid = blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x;
for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
int r = rng->relativeInt(i) % i;
if (i != r) {
T e0 = input[shape::getIndexOffset(i, shape, len)];
T e1 = input[shape::getIndexOffset(r, shape, len)];
//math::nd4j_swap<T>(input(i), input(r));
input[shape::getIndexOffset(i, shape, len)] = e1;
input[shape::getIndexOffset(r, shape, len)] = e0;
}
}
}
template <typename T>
static __global__ void fillShuffleKernel(T* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape, Nd4jLong firstDim, Nd4jLong len, int* indices, nd4j::graph::RandomGenerator* rng) {
// PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
auto tid = blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x;
for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
int r = rng->relativeInt(i) % i;
output[shape::getIndexOffset(i, outputShape, len)] = input[shape::getIndexOffset(indices[r], inputShape, len)];
if(i != r) {
output[shape::getIndexOffset(r, outputShape, len)] = input[shape::getIndexOffset(indices[i], inputShape, len)];
// output.p(r, input.e<T>(indices[i]));
// math::nd4j_swap<int>(indices[i], indices[r]);
atomicExch(&indices[i], indices[r]);
}
}
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
void randomShuffle_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
// check edge cases first
int temp;
const int firstDim = input.sizeAt(0);
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({&output}, {&input});
if(input.lengthOf() == 1 || firstDim == 1) {
if(!isInplace)
output.assign(input);
}
else if (input.isVector() || shape::isLikeVector(input.getShapeInfo(), temp)) {
// apply Fisher-Yates shuffle
nd4j::graph::RandomGenerator* dRandom = nullptr;
cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator));
cudaMemcpy(dRandom, &rng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice);
T* inputBuf = reinterpret_cast<T*>(input.specialBuffer());
if(isInplace) {
swapShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, input.lengthOf(), dRandom);
}
else {
std::vector<int> indices(firstDim);
std::iota(indices.begin(), indices.end(), 0);
cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice);
//output.p<T>(Nd4jLong(0), input.e<T>(0));
PointersManager pointersManager(context, "helper::randomShuffle_");
int* indicesDev = reinterpret_cast<int*>(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int)));
T* outputBuf = reinterpret_cast<T*>(output.specialBuffer());
fillShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, input.lengthOf(), indicesDev, dRandom);
pointersManager.synchronize();
}
// rng.rewindH(firstDim - 1);
cudaFree(dRandom);
}
else {
// evaluate sub-arrays list of input array through all dimensions excluding first one
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
// apply Fisher-Yates shuffle
if(isInplace) {
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
for(int i = firstDim - 1; i > 0; --i) {
int r = rng.relativeInt(i) % i;
if(i != r)
subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r));
}
}
else {
// evaluate sub-arrays list of output array through all dimensions excluding first one
auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
std::vector<int> indices(firstDim);
std::iota(indices.begin(), indices.end(), 0);
bool isZeroShuffled = false;
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
for(int i = firstDim - 1; i > 0; --i) {
int r = rng.relativeInt(i) % i;
subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r]));
if(r == 0)
isZeroShuffled = true;
if(i != r) {
subArrsListOut->at(r)->assign(subArrsListIn->at(indices[i]));
math::nd4j_swap<int>(indices[i], indices[r]);
}
}
if(!isZeroShuffled)
subArrsListOut->at(0)->assign(subArrsListIn->at(0));
delete subArrsListOut;
}
rng.rewindH(firstDim-1);
delete subArrsListIn;
}
NDArray::registerSpecialUse({&output}, {&input});
} }
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) { void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
@ -496,11 +715,6 @@ void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArr
void eye(nd4j::LaunchContext * context, NDArray& output) { void eye(nd4j::LaunchContext * context, NDArray& output) {
output.setIdentity(); output.setIdentity();
}
//////////////////////////////////////////////////////////////////////////
void scatterUpdate(nd4j::LaunchContext * context, NDArray& operand, NDArray& updates, const std::vector<int>* intArgs) {
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////

View File

@ -33,27 +33,7 @@ namespace helpers {
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h); void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h);
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0, void gruCellBP(nd4j::LaunchContext* context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, NDArray* dLdb, NDArray* dLdbc);
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb);
//////////////////////////////////////////////////////////////////////////
FORCEINLINE NDArray sigmoid(const NDArray& arr) {
return (const_cast<NDArray&>(arr)).transform(transform::Sigmoid);
}
FORCEINLINE void sigmoidInplace(const NDArray& arr) {
(const_cast<NDArray&>(arr)).applyTransform(transform::Sigmoid);
}
//////////////////////////////////////////////////////////////////////////
FORCEINLINE NDArray tanh(const NDArray& arr) {
return (const_cast<NDArray&>(arr)).transform(transform::Tanh);
}
FORCEINLINE void tanhInplace(const NDArray& arr) {
(const_cast<NDArray&>(arr)).applyTransform(transform::Tanh);
}
} }
} }

View File

@ -27,37 +27,37 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename T>
FORCEINLINE Nd4jLong longBytes(T value); FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value);
template <> template <>
FORCEINLINE Nd4jLong longBytes(float value) { FORCEINLINE _CUDA_HD Nd4jLong longBytes(float value) {
int intie = *(int *)&value; int intie = *(int *)&value;
return static_cast<Nd4jLong>(intie); return static_cast<Nd4jLong>(intie);
} }
template <> template <>
FORCEINLINE Nd4jLong longBytes(double value) { FORCEINLINE _CUDA_HD Nd4jLong longBytes(double value) {
Nd4jLong longie = *(Nd4jLong *)&value; Nd4jLong longie = *(Nd4jLong *)&value;
return longie; return longie;
} }
template <> template <>
FORCEINLINE Nd4jLong longBytes(float16 value) { FORCEINLINE _CUDA_HD Nd4jLong longBytes(float16 value) {
return longBytes<float>((float) value); return longBytes<float>((float) value);
} }
template <> template <>
FORCEINLINE Nd4jLong longBytes(Nd4jLong value) { FORCEINLINE _CUDA_HD Nd4jLong longBytes(Nd4jLong value) {
return value; return value;
} }
template <> template <>
FORCEINLINE Nd4jLong longBytes(bfloat16 value) { FORCEINLINE _CUDA_HD Nd4jLong longBytes(bfloat16 value) {
return longBytes<float>((float) value); return longBytes<float>((float) value);
} }
template <typename T> template <typename T>
FORCEINLINE Nd4jLong longBytes(T value) { FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value) {
return longBytes<Nd4jLong>((Nd4jLong) value); return longBytes<Nd4jLong>((Nd4jLong) value);
} }

View File

@ -30,37 +30,30 @@ namespace helpers {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
static FORCEINLINE NDArray activation(const NDArray& arr) { void rnnCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* hPrev, NDArray* ht) {
return (const_cast<NDArray&>(arr)).transform(transform::Tanh); // xt input [bS x iS]
// Wx input-to-hidden weights, [iS x nU]
// Wh hidden-to-hidden weights, [nU x nU]
// b biases, [2*nU]: {0, nU} are input-to-hidden biases and {nU, 2*nU} are hidden-to-hidden biases
// hPrev previous cell output [bS x nU], that is at previous time step t-1, in case of projection=false -> nU=nU!!!
const int nU = hPrev->sizeAt(1);
// ht is current cell output [bS x nU], that is at current time step t
ht->assign(mmul(*xt, *Wx) + (*b)({{0, nU}}) + mmul(*hPrev, *Wh) + (*b)({{nU, 2*nU}})); // [bS x nU] + [nU] + [bS x nU] + [nU] = [bS x nU]
ht->applyTransform(transform::Tanh);
} }
//////////////////////////////////////////////////////////////////////////
void rnnCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* ht_1, NDArray* ht) {
// xt input [bS x inSize]
// Wx input-to-hidden weights, [inSize x numUnits]
// Wh hidden-to-hidden weights, [numUnits x numUnits]
// b biases, [2*numUnits]: {0, numUnits} are input-to-hidden biases and {numUnits, 2*numUnits} are hidden-to-hidden biases
// ht_1 previous cell output [bS x numUnits], that is at previous time step t-1, in case of projection=false -> numUnits=numUnits!!!
const int numUnits = ht_1->sizeAt(1);
// ht is current cell output [bS x numUnits], that is at current time step t
ht->assign(activation(mmul(*xt, *Wx) + (*b)({{0, numUnits}}) + mmul(*ht_1, *Wh) + (*b)({{numUnits, 2*numUnits}}))); // [bS x numUnits] + [numUnits] + [bS x numUnits] + [numUnits] = [bS x numUnits]
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal) { void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal) {
// x input [time x bS x inSize] // x input [time x bS x iS]
// Wx input-to-hidden weights, [inSize x numUnits] // Wx input-to-hidden weights, [iS x nU]
// Wh hidden-to-hidden weights, [numUnits x numUnits] // Wh hidden-to-hidden weights, [nU x nU]
// b biases for, [2*numUnits] // b biases for, [2*nU]
// h0 initial cell output (at time step = 0) [bS x numUnits] // h0 initial cell output (at time step = 0) [bS x nU]
// maxTimeStep vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep // maxTimeStep vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
const int time = x->sizeAt(0); const int time = x->sizeAt(0);
@ -82,16 +75,16 @@ void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
auto xt = (*x)({t,t+1, e,e+1, 0,0}, true); auto xt = (*x)({t,t+1, e,e+1, 0,0}, true);
auto ht = (*h)({t,t+1, e,e+1, 0,0}, true); auto ht = (*h)({t,t+1, e,e+1, 0,0}, true);
auto ht_1 = (*hFinal)({e,e+1, 0,0}, true); // previous state auto hPrev = (*hFinal)({e,e+1, 0,0}, true); // previous state
if(t >= maxStep) { if(t >= maxStep) {
ht = 0.; ht = 0.;
if(maxStep != 0) if(maxStep != 0)
ht_1.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0})); hPrev.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0}));
} }
else { else {
helpers::rnnCell(context, &xt, Wx, Wh, b, &ht_1, &ht); helpers::rnnCell(context, &xt, Wx, Wh, b, &hPrev, &ht);
ht_1.assign(ht); hPrev.assign(ht);
} }
} }
} }

View File

@ -0,0 +1,36 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
// @brief helpers common fuctions for segment_* ops (segment_max, segment_min, etc.)
// @brief helpers common fuctions for unsorted_segment_* ops (unsorted_segment_max, etc.)
//
#ifndef __SEGMENT_COMMON_HELPERS__
#define __SEGMENT_COMMON_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens);
}
}
}
#endif

View File

@ -23,6 +23,7 @@
#include <ops/declarable/helpers/helpers.h> #include <ops/declarable/helpers/helpers.h>
#include <helpers/helper_random.h> #include <helpers/helper_random.h>
#include <graph/RandomGenerator.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -32,7 +33,7 @@ namespace helpers {
void trace(nd4j::LaunchContext * context, const NDArray& input, NDArray& output); void trace(nd4j::LaunchContext * context, const NDArray& input, NDArray& output);
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace); void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace);
// auxiliary function which serves for recursion purpose and is used in pad operation // auxiliary function which serves for recursion purpose and is used in pad operation
// void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue); // void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue);

View File

@ -1126,15 +1126,7 @@ inline __device__ bool nd4j_atomicAdd<bool>(bool* address, bool val) {
template <> template <>
inline __device__ double nd4j_atomicSub<double>(double* address, double val) { inline __device__ double nd4j_atomicSub<double>(double* address, double val) {
unsigned long long int* address_as_ull = return nd4j_atomicAdd<double>(address, -val);
(unsigned long long int *) address;
unsigned long long int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val -
__longlong_as_double(assumed)));
} while (assumed != old);
return __longlong_as_double(old);
} }
template <> template <>
@ -1152,15 +1144,7 @@ inline __device__ double nd4j_atomicMul<double>(double* address, double val) {
template <> template <>
inline __device__ double nd4j_atomicDiv<double>(double* address, double val) { inline __device__ double nd4j_atomicDiv<double>(double* address, double val) {
unsigned long long int* address_as_ull = return nd4j_atomicMul<double>(address, 1./val);
(unsigned long long int*) address;
unsigned long long int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val /
__longlong_as_double(assumed)));
} while (assumed != old);
return __longlong_as_double(old);
} }
template <> template <>
@ -1179,14 +1163,16 @@ inline __device__ int32_t nd4j_atomicAdd<int32_t>(int32_t* address, int32_t val)
template <> template <>
inline __device__ float nd4j_atomicSub<float>(float* address, float val) { inline __device__ float nd4j_atomicSub<float>(float* address, float val) {
int* address_as_ull = (int*) address; return nd4j_atomicAdd<float>(address, -val);
int old = *address_as_ull, assumed; }
do {
assumed = old; template <>
old = atomicCAS(address_as_ull, assumed, __float_as_int(val - inline __device__ float16 nd4j_atomicSub<float16>(float16* address, float16 val) {
__float_as_int(assumed))); return nd4j_atomicAdd<float16>(address, -val);
} while (assumed != old); }
return __int_as_float(old); template <>
inline __device__ bfloat16 nd4j_atomicSub<bfloat16>(bfloat16* address, bfloat16 val) {
return nd4j_atomicAdd<bfloat16>(address, -val);
} }
template <> template <>
@ -1415,6 +1401,30 @@ inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val)
template <> template <>
inline __device__ float nd4j_atomicDiv<float>(float* address, float val) { inline __device__ float nd4j_atomicDiv<float>(float* address, float val) {
int* address_as_ull =
(int*)address;
int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, __float_as_int(__int_as_float(assumed) / val ));
} while (assumed != old);
return __int_as_float(old);
}
template <>
inline __device__ float16 nd4j_atomicDiv<float16>(float16* address, float16 val) {
int* address_as_ull =
(int*)address;
int old = *address_as_ull, assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
__float_as_int(assumed)));
} while (assumed != old);
return __int_as_float(old);
}
template <>
inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 val) {
int* address_as_ull = int* address_as_ull =
(int*)address; (int*)address;
int old = *address_as_ull, assumed; int old = *address_as_ull, assumed;

View File

@ -76,6 +76,9 @@
(nd4j::DataType::FLOAT32, float), \ (nd4j::DataType::FLOAT32, float), \
(nd4j::DataType::DOUBLE, double) (nd4j::DataType::DOUBLE, double)
#define FLOAT_NATIVE \
(nd4j::DataType::FLOAT32, float), \
(nd4j::DataType::DOUBLE, double)
#define FLOAT_TYPES_0 \ #define FLOAT_TYPES_0 \
(nd4j::DataType::HALF, float16) (nd4j::DataType::HALF, float16)

View File

@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0); NDArray* result = results->at(0);
result->printIndexedBuffer("OOOOUUUUTTT"); // result->printIndexedBuffer("OOOOUUUUTTT");
ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
@ -1881,9 +1881,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
NDArray boxes = NDArrayFactory::create<float>('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f, NDArray boxes = NDArrayFactory::create<double>('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f,
0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101}); 0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101});
NDArray scales = NDArrayFactory::create<float>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5 NDArray scales = NDArrayFactory::create<double>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5
NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5}); NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5});
nd4j::ops::non_max_suppression op; nd4j::ops::non_max_suppression op;
@ -1892,7 +1892,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0); NDArray* result = results->at(0);
result->printBuffer("NonMaxSuppression OUtput2"); // result->printBuffer("NonMaxSuppression OUtput2");
ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result)); ASSERT_TRUE(expected.equalsTo(result));
@ -1970,6 +1970,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
delete results; delete results;
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {

View File

@ -421,3 +421,200 @@ ASSERT_TRUE(result->at(0)->e<bool>(0));
//ASSERT_TRUE(exp.equalsTo(result->at(0))); //ASSERT_TRUE(exp.equalsTo(result->at(0)));
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustHue_1) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {0.5}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustHue_2) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,2,3}, {4,100,0, 146,220,5, 97,123.8,230, 255,2,164.8}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {0.9}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustHue_3) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {-0.9}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustHue_4) {
NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {0.5}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustHue_5) {
NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {0.5}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {0.5}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {10}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustSaturation_3) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {-10}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustSaturation_4) {
NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {0.5}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, nd4j::DataType::FLOAT32);
NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {0.5}, {0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}

View File

@ -1479,6 +1479,27 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {
delete results; delete results;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test04) {
auto input = NDArrayFactory::create<double>('c', {4});
input.linspace(1);
nd4j::ops::random_shuffle op;
//NDArray* output;
auto results = op.execute({&input}, {}, {}, {}, true, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), results->status());
auto output = &input; //results->at(0);
bool haveZeros = false;
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_TRUE(input.isSameShape(output));
//ASSERT_TRUE(!input.equalsTo(output));
ASSERT_TRUE(!haveZeros);
delete results;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test4) { TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
@ -1486,17 +1507,17 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
input.linspace(1); input.linspace(1);
nd4j::ops::random_shuffle op; nd4j::ops::random_shuffle op;
//NDArray* output;
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), results->status());
auto output = results->at(0); auto output = results->at(0);
bool haveZeros = false; bool haveZeros = false;
for(int i = 0; i < output->lengthOf(); ++i) for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.) if(output->e<float>(i) == (float)0.)
haveZeros = true; haveZeros = true;
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(input.isSameShape(output));
ASSERT_TRUE(!input.equalsTo(output)); //ASSERT_TRUE(!input.equalsTo(output));
ASSERT_TRUE(!haveZeros); ASSERT_TRUE(!haveZeros);
delete results; delete results;

View File

@ -1601,8 +1601,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
// z->printIndexedBuffer("Output "); z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected "); exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -1610,6 +1610,75 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
2., 4., 60., 8., 10.,
0., 1., 2., 3., 4.,
0., 0., 2., 4., 6.,
0., 0., 0., 1., 2.,
0., 0., 0., 0., 4.
});
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
0.5, -2.0, -13.0, 54.0, -6.75,
0.0, 1.0, -1.0, 1.0, 0.0,
0, 0, 0.5, -2.0, 0.25,
0, 0, 0, 1.0, -0.5,
0, 0, 0, 0, 0.25
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
1., 0., 0., 0., 0.,
2., 1., 0., 0., 0.,
30., 2., 1., 0., 0.,
4., 3., 2., 1., 0.,
5., 4., 3., 2., 1.
});
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
1.0, 0.0, 0.0, 0.0, 0.,
-2.0, 1.0, 0., 0., 0.,
-26.0, -2.0, 1, 0, 0.,
54.0, 1.0, -2.0, 1, 0.,
-27.0, 0.0, 1.0, -2.0, 1.
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
/* /*
@ -1658,6 +1727,39 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
delete result; delete result;
} }
*/ */
TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
4., 0., 0., 0., 0.,
4., 2., 0., 0., 0.,
30., 2., 1., 0., 0.,
8., 6., 4., 2., 0.,
15., 12., 9., 6., 3.,
});
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
0.25, 0.0, 0.0, 0.0, 0.0,
-0.50, 0.5, 0.0, 0.0, 0.0,
-6.50, -1.0, 1.0, 0.0, 0.0,
13.50, 0.5, -2.0, 0.5, 0.0,
-6.75, 0.0, 1.0, -1.0, 0.33333333
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_3) { TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
@ -1695,7 +1797,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_4) { TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
auto x = NDArrayFactory::create<double>('c', {5, 5}, { auto x = NDArrayFactory::create<float>('c', {5, 5}, {
1., 2., 30., 4., 5., 1., 2., 30., 4., 5.,
0., 1., 2., 3., 4., 0., 1., 2., 3., 4.,
0., 0., 1., 2., 3., 0., 0., 1., 2., 3.,
@ -1703,7 +1805,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
0., 0., 0., 0., 1. 0., 0., 0., 0., 1.
}); });
auto exp = NDArrayFactory::create<double>('c', {5, 5}, { auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
1.0, -2.0, -26.0, 54.0, -27.0, 1.0, -2.0, -26.0, 54.0, -27.0,
0.0, 1.0, -2.0, 1.0, 0.0, 0.0, 1.0, -2.0, 1.0, 0.0,
0.0, 0.0, 1.0, -2.0, 1.0, 0.0, 0.0, 1.0, -2.0, 1.0,
@ -1712,13 +1814,13 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0); auto z = result->at(0);
//z->printIndexedBuffer("Output "); z->printIndexedBuffer("Output ");
//exp.printIndexedBuffer("Expected "); exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));

View File

@ -763,15 +763,15 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) {
TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) {
auto input = NDArrayFactory::create<double>('c', {4, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f}); auto input = NDArrayFactory::create<int>('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
auto exp = NDArrayFactory::create<double>('c', {4, 4, 16}, {1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, auto exp = NDArrayFactory::create<bool>('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f }); 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 });
nd4j::ops::sequence_mask op; nd4j::ops::sequence_mask op;
auto result = op.execute({&input}, {}, {}); auto result = op.execute({&input}, {}, {});
@ -788,19 +788,19 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) {
} }
TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) {
auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, {10., 20., 30., 4., 0., 6., 7., 8.}); auto input = NDArrayFactory::create<int>('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8});
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2, 30}, { 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., auto exp = NDArrayFactory::create<bool>('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
nd4j::ops::sequence_mask op; nd4j::ops::sequence_mask op;
auto result = op.execute({&input}, {}, {}); auto result = op.execute({&input}, {}, {});
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
// z->printIndexedBuffer("Output"); // z->printBuffer("Output");
// z->printShapeInfo("Shape"); // z->printShapeInfo("Shape");
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));

View File

@ -2770,9 +2770,8 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
ASSERT_TRUE(isGradCorrect); ASSERT_TRUE(isGradCorrect);
} }
////////////////////////////////////////////////////////////////////
/* /*
//2019/02/23 AB - GRU backprop tests disabled pending update of GRU backprop op after rewriting forward pass ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
const int bS = 2; const int bS = 2;
@ -2780,160 +2779,58 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
const int nU = 4; const int nU = 4;
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE); NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE); NDArray hi('c', {bS, nU}, nd4j::DataType::DOUBLE);
NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE); NDArray W('c', {iS+nU, 2*nU}, nd4j::DataType::DOUBLE);
NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE); NDArray Wc('c', {iS+nU, nU}, nd4j::DataType::DOUBLE);
NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE); NDArray b('c', {2*nU}, nd4j::DataType::DOUBLE);
NDArray bc('c', {nU}, nd4j::DataType::DOUBLE);
NDArray dLdr('c', {bS, nU}, nd4j::DataType::DOUBLE);
NDArray dLdu('c', {bS, nU}, nd4j::DataType::DOUBLE);
NDArray dLdc('c', {bS, nU}, nd4j::DataType::DOUBLE);
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE); NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
x.linspace(0.5, 0.5); x.linspace(-5, 0.5);
h0 = 1.; hi = 1.;
Wx = 0.003; W = 0.003;
Wh = 0.006; Wc = 0.006;
b = 0.5; b = 0.5;
bc = 0.35;
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {}); const OpArgsHolder argsHolderFF({&x, &hi, &W, &Wc, &b, &bc}, {}, {});
nd4j::ops::gruCell op;
auto results = op.execute(argsHolderFF);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto u = results->at(1); // [bS, nU]
auto c = results->at(2); // [bS, nU]
auto h = results->at(3); // [bS, nU]
dLdh = 1.; // SUM loss
NDArray Wch = Wc({iS,iS+nU, 0,0}); // [nU, nU]
NDArray dhdc = 1. - *u;
NDArray dhdu = hi - *c;
NDArray dcdZc = 1. - *c * *c;
dLdc.assign(dLdh * dhdc);
dLdu.assign(dLdh * dhdu);
dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose()));
delete results;
const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {});
nd4j::ops::gruCell opFF; nd4j::ops::gruCell opFF;
nd4j::ops::gruCell_bp opBP; nd4j::ops::gruCell_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1 , 1, 1}, {0., 1.}, nd4j::GradCheck::LossFunc::SUM, true);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, gru_cell_bp_test2) {
const int bS = 2;
const int iS = 3;
const int nU = 4;
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE);
NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE);
NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE);
NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE);
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
x.linspace(0.5, 0.5);
h0 = 1.;
Wx = 0.003;
Wh = 0.006;
b = 0.;
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
nd4j::ops::gruCell opFF;
nd4j::ops::gruCell_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, gru_cell_bp_test3) {
const int bS = 2;
const int iS = 3;
const int nU = 4;
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE);
NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE);
NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE);
NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE);
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
// NDArray<double> dLdWx0('c', {iS, 3*nU});
// NDArray<double> dLdWh0('c', {nU, 3*nU});
// NDArray<double> dLdb0 ('c', {3*nU});
x = 1.;
h0 = 0.0;
Wx = 0.0;
Wh = 0.0;
b = 0.5;
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
nd4j::ops::gruCell opFF;
nd4j::ops::gruCell_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests9, gru_bp_test1) {
// const int time = 5;
// const int bS = 2;
// const int iS = 3;
// const int nU = 4;
// NDArray<double> x ('c', {time, bS, iS});
// NDArray<double> h0 ('c', {bS, nU});
// NDArray<double> Wx ('c', {iS, 3*nU});
// NDArray<double> Wh ('c', {nU, 3*nU});
// NDArray<double> b ('c', {3*nU});
// NDArray<double> dLdh ('c', {time, bS, nU});
// x.linspace(0.5, 0.5);
// h0 = 1.;
// Wx = 0.003;
// Wh = 0.006;
// b = 0.5;
// const OpArgsHolder<double> argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
// const OpArgsHolder<double> argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
// nd4j::ops::gru<double> opFF;
// nd4j::ops::gru_bp<double> opBP;
// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
// ASSERT_TRUE(isGradCorrect);
// }
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, gru_cell_bp_test3_1) {
const int bS = 2;
const int iS = 3;
const int nU = 4;
auto x = NDArrayFactory::create<double>('c', {bS, iS});
auto h0 = NDArrayFactory::create<double>('c', {bS, nU});
auto Wx = NDArrayFactory::create<double>('c', {iS, 3*nU});
auto Wh = NDArrayFactory::create<double>('c', {nU, 3*nU});
auto b = NDArrayFactory::create<double>('c', {3*nU});
auto dLdh = NDArrayFactory::create<double>('c', {bS, nU});
// NDArray<double> dLdWx0('c', {iS, 3*nU});
// NDArray<double> dLdWh0('c', {nU, 3*nU});
// NDArray<double> dLdb0 ('c', {3*nU});
x = 1.;
h0 = 0.0;
Wx = 0.0;
Wh = 0.0;
b = 0.5;
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
nd4j::ops::gruCell opFF;
nd4j::ops::gruCell_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect); ASSERT_TRUE(isGradCorrect);
} }
*/ */
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, Cholesky_Test_1) { TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {

View File

@ -719,6 +719,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
} }
TEST_F(ParityOpsTests, Test_Scatter_Add_2) { TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4}); auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4});
NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64); NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1});
@ -1588,36 +1589,79 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
delete result; delete result;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatter_update_1) { TEST_F(ParityOpsTests, scatter_update_1) {
auto matrix = NDArrayFactory::create_<float>('c', {3, 2});
auto updates = NDArrayFactory::create_<float>('c', {2, 2});
updates->assign(1.0);
//updates.printBuffer("Updates"); NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
NDArray updates('c', {2,2}, {10,20,30,40}, nd4j::DataType::INT32);
auto variableSpace = new VariableSpace(); NDArray exp('c', {2,2}, {30,40,10,20}, nd4j::DataType::INT32);
variableSpace->putVariable(-1, matrix);
variableSpace->putVariable(-2, updates);
variableSpace->putVariable(1, new Variable(&matrix));
auto block = new Context(1, variableSpace, false);
block->fillInputs({-1, -2});
std::vector<int>* arguments = block->getIArguments();
arguments->push_back(0);
arguments->push_back(1);
arguments->push_back(1);
arguments->push_back(2);
arguments->push_back(1);
arguments->push_back(2);
nd4j::ops::scatter_update op; nd4j::ops::scatter_update op;
auto results = op.execute({&x, &updates}, {}, {6, 1,1, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
// x.printBuffer();
Nd4jStatus result = op.execute(block); ASSERT_TRUE(exp.isSameShape(x));
ASSERT_EQ(ND4J_STATUS_OK, result); ASSERT_TRUE(exp.equalsTo(x));
delete block; delete results;
delete variableSpace;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatter_update_2) {
NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
NDArray updates('c', {2,2}, {10,20,30,40}, nd4j::DataType::INT32);
NDArray exp('c', {2,2}, {20,10,40,30}, nd4j::DataType::INT32);
nd4j::ops::scatter_update op;
auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
ASSERT_TRUE(exp.isSameShape(x));
ASSERT_TRUE(exp.equalsTo(x));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatter_update_3) {
NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, nd4j::DataType::INT32);
NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, nd4j::DataType::INT32);
NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, nd4j::DataType::INT32);
nd4j::ops::scatter_update op;
auto results = op.execute({&x, &updates}, {}, {6, 2,1,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
ASSERT_TRUE(exp.isSameShape(x));
ASSERT_TRUE(exp.equalsTo(x));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ParityOpsTests, scatter_update_4) {
NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, nd4j::DataType::INT32);
NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, nd4j::DataType::INT32);
NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, nd4j::DataType::INT32);
nd4j::ops::scatter_update op;
auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,3,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
ASSERT_TRUE(exp.isSameShape(x));
ASSERT_TRUE(exp.equalsTo(x));
delete results;
}

View File

@ -278,8 +278,8 @@ TEST_F(RNGTests, Test_Gaussian_22) {
auto x0 = NDArrayFactory::create<float>('c', {10000, 1000}); auto x0 = NDArrayFactory::create<float>('c', {10000, 1000});
auto x1 = NDArrayFactory::create<float>('c', {10000, 1000}); auto x1 = NDArrayFactory::create<float>('c', {10000, 1000});
RandomLauncher::fillGaussian(_rngA, &x0, 0.0f, 1.0f); RandomLauncher::fillGaussian(nd4j::LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f);
RandomLauncher::fillGaussian(_rngB, &x1, 0.0f, 1.0f); RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f);
//x0.printIndexedBuffer("x0"); //x0.printIndexedBuffer("x0");
//x1.printIndexedBuffer("x1"); //x1.printIndexedBuffer("x1");
@ -306,7 +306,7 @@ TEST_F(RNGTests, Test_Gaussian_22) {
TEST_F(RNGTests, Test_Gaussian_3) { TEST_F(RNGTests, Test_Gaussian_3) {
auto x0 = NDArrayFactory::create<double>('c', {10000000}); auto x0 = NDArrayFactory::create<double>('c', {10000000});
RandomLauncher::fillGaussian(_rngA, &x0, 0.0, 1.0); RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, 1.0);
auto mean = x0.meanNumber().e<double>(0); auto mean = x0.meanNumber().e<double>(0);
auto stdev = x0.varianceNumber(nd4j::variance::SummaryStatsStandardDeviation, false).e<double>(0); auto stdev = x0.varianceNumber(nd4j::variance::SummaryStatsStandardDeviation, false).e<double>(0);
@ -319,8 +319,8 @@ TEST_F(RNGTests, Test_LogNormal_1) {
auto x0 = NDArrayFactory::create<float>('c', {10, 10}); auto x0 = NDArrayFactory::create<float>('c', {10, 10});
auto x1 = NDArrayFactory::create<float>('c', {10, 10}); auto x1 = NDArrayFactory::create<float>('c', {10, 10});
RandomLauncher::fillLogNormal(_rngA, &x0, 1.0f, 2.0f); RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
RandomLauncher::fillLogNormal(_rngB, &x1, 1.0f, 2.0f); RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
ASSERT_TRUE(x0.equalsTo(&x1)); ASSERT_TRUE(x0.equalsTo(&x1));
@ -333,8 +333,8 @@ TEST_F(RNGTests, Test_Truncated_1) {
auto x0 = NDArrayFactory::create<float>('c', {10, 10}); auto x0 = NDArrayFactory::create<float>('c', {10, 10});
auto x1 = NDArrayFactory::create<float>('c', {10, 10}); auto x1 = NDArrayFactory::create<float>('c', {10, 10});
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
ASSERT_TRUE(x0.equalsTo(&x1)); ASSERT_TRUE(x0.equalsTo(&x1));
@ -357,8 +357,8 @@ TEST_F(RNGTests, Test_Truncated_2) {
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
ASSERT_TRUE(x0.equalsTo(&x1)); ASSERT_TRUE(x0.equalsTo(&x1));
@ -383,8 +383,8 @@ TEST_F(RNGTests, Test_Truncated_21) {
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
ASSERT_TRUE(x0.equalsTo(&x1)); ASSERT_TRUE(x0.equalsTo(&x1));
@ -430,8 +430,8 @@ TEST_F(RNGTests, Test_Truncated_22) {
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 2.0f, 4.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 2.0f, 4.0f);
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 2.0f, 4.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 2.0f, 4.0f);
ASSERT_TRUE(x0.equalsTo(&x1)); ASSERT_TRUE(x0.equalsTo(&x1));
@ -477,8 +477,8 @@ TEST_F(RNGTests, Test_Truncated_23) {
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000}); auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000}); auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 0.0f, 1.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f);
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 0.0f, 1.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f);
ASSERT_TRUE(x0.equalsTo(&x1)); ASSERT_TRUE(x0.equalsTo(&x1));
@ -524,8 +524,8 @@ TEST_F(RNGTests, Test_Truncated_3) {
auto x0 = NDArrayFactory::create<float>('c', {10000, 1000}); auto x0 = NDArrayFactory::create<float>('c', {10000, 1000});
auto x1 = NDArrayFactory::create<float>('c', {10000, 1000}); auto x1 = NDArrayFactory::create<float>('c', {10000, 1000});
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f); RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
ASSERT_TRUE(x0.equalsTo(&x1)); ASSERT_TRUE(x0.equalsTo(&x1));
@ -964,7 +964,7 @@ TEST_F(RNGTests, Test_Reproducibility_2) {
TEST_F(RNGTests, Test_Uniform_4) { TEST_F(RNGTests, Test_Uniform_4) {
auto x1 = NDArrayFactory::create<double>('c', {1000000}); auto x1 = NDArrayFactory::create<double>('c', {1000000});
RandomLauncher::fillUniform(_rngB, &x1, 1.0, 2.0); RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0, 2.0);
/* Check up distribution */ /* Check up distribution */
auto mean = x1.reduceNumber(reduce::Mean); auto mean = x1.reduceNumber(reduce::Mean);

View File

@ -69,6 +69,24 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_1) {
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);
} }
TEST_F(SortCudaTests, test_linear_sort_by_val_2) {
auto k = NDArrayFactory::create<int>('c', {6}, {0, 1, 2, 3, 4, 5});
// auto v = NDArrayFactory::create<double>('c', {6}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5});
NDArray v = NDArrayFactory::create<double>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
auto ek = NDArrayFactory::create<int>('c', {6}, {3, 0, 1, 2, 4, 5});
auto ev = NDArrayFactory::create<double>('c', {6}, {0.95, 0.9, 0.75, 0.6, 0.5, 0.3});
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NativeOps nativeOps;
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
k.tickWriteDevice();
v.tickWriteDevice();
k.printIndexedBuffer("KEYS");
ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v);
}
TEST_F(SortCudaTests, test_tad_sort_by_key_1) { TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
auto k = NDArrayFactory::create<Nd4jLong>('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); auto k = NDArrayFactory::create<Nd4jLong>('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8});
auto v = NDArrayFactory::create<double>('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); auto v = NDArrayFactory::create<double>('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5});