[WIP] More of CUDA operations (#69)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * - gruCell_bp further Signed-off-by: Yurii <yurii@skymind.io> * - further work on gruCell_bp Signed-off-by: Yurii <yurii@skymind.io> * Inverse matrix cublas implementation. Partial working revision. * Separation of segment ops helpers. Max separation. * Separated segment_min ops. * Separation of segment_mean/sum/prod/sqrtN ops helpers. * Fixed diagonal processing with LUP decomposition. * Modified inversion approach using current state of LU decomposition. * Implementation of matrix_inverse op with cuda kernels. Working revision. * Implemented sequence_mask cuda helper. Eliminated waste printf with matrix_inverse implementation. Added proper tests. * - further work on gruCell_bp (ff/cuda) Signed-off-by: Yurii <yurii@skymind.io> * comment one test for gruCell_bp Signed-off-by: Yurii <yurii@skymind.io> * - provide cuda static_rnn Signed-off-by: Yurii <yurii@skymind.io> * Refactored random_shuffle op to use new random generator. * Refactored random_shuffle op helper. * Fixed debug tests with random ops tests. * Implement random_shuffle op cuda kernel helper and tests. * - provide cuda scatter_update Signed-off-by: Yurii <yurii@skymind.io> * Implementation of random_shuffle for linear case with cuda kernels and tests. * Implemented random_shuffle with cuda kernels. Final revision. * - finally gruCell_bp is completed Signed-off-by: Yurii <yurii@skymind.io> * Dropout op cuda helper implementation. * Implemented dropout_bp cuda helper. * Implemented alpha_dropout_bp with cuda kernel helpers. * Refactored helper. * Implementation of suppression helper with cuda kernels. * - provide cpu code for hsvToRgb, rgbToHsv, adjustHue Signed-off-by: Yurii <yurii@skymind.io> * Using sort by value method. * Implementation of image.non_max_suppression op cuda-based helper. 
* - correcting and testing adjust_hue, adjust_saturation cpu/cuda code Signed-off-by: Yurii <yurii@skymind.io> * Added cuda device prefixes to declarations. * Implementation of hashcode op with cuda helper. Initial revision. * rnn cu impl removed Signed-off-by: raver119 <raver119@gmail.com>master
parent
06e4f5f96e
commit
763a225c6a
|
@ -208,9 +208,9 @@ namespace nd4j {
|
|||
NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||
|
||||
/**
|
||||
* this constructor creates new array using given buffer (without memory allocating) and shape information stored in shape
|
||||
* this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape
|
||||
*/
|
||||
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
||||
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
|
||||
|
||||
/**
|
||||
* this constructor creates new NDArray with shape matching "other" array,
|
||||
|
|
|
@ -132,9 +132,8 @@ NDArray::NDArray(const NDArray *other, const bool copyStrides, nd4j::LaunchConte
|
|||
_buffer = std::make_shared<DataBuffer>(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace());
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
||||
NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context, const bool isBuffAlloc) {
|
||||
|
||||
if (shape.empty())
|
||||
throw std::runtime_error("NDArray constructor: input shape is empty !");
|
||||
|
@ -148,7 +147,7 @@ NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &sh
|
|||
|
||||
setShapeInfo(ShapeDescriptor(dtype, order, shape));
|
||||
|
||||
_buffer = std::make_shared<DataBuffer>(buffer, lengthOf() * sizeOfT(), dataType(), true, getContext()->getWorkspace());
|
||||
_buffer = std::make_shared<DataBuffer>(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -1498,16 +1498,6 @@ void NativeOps::specialConcat(
|
|||
* This method saves
|
||||
*/
|
||||
nd4j::TadPack* NativeOps::tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensionLength) {
|
||||
/*shape::TAD tad;
|
||||
tad.init(dXShapeInfo, dimension, dimensionLength);
|
||||
//tad->setOutputBuffer(target);
|
||||
tad.createTadOnlyShapeInfo();
|
||||
tad.createOffsets();
|
||||
|
||||
|
||||
std::memcpy(reinterpret_cast<void *>(target), tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo));
|
||||
std::memcpy(reinterpret_cast<void *>(offsets), tad.tadOffsets, tad.numTads * sizeof(Nd4jLong));
|
||||
*/
|
||||
auto pack = new TadPack();
|
||||
*pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength);
|
||||
return pack;
|
||||
|
|
|
@ -45,9 +45,9 @@ namespace nd4j {
|
|||
|
||||
static ConstantTadHelper* getInstance();
|
||||
|
||||
TadPack& tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||
TadPack& tadForDimensions(TadDescriptor &descriptor);
|
||||
};
|
||||
|
|
|
@ -38,15 +38,15 @@ namespace nd4j {
|
|||
return _INSTANCE;
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
||||
return tadForDimensions(tadDescriptor);
|
||||
}
|
||||
|
|
|
@ -42,15 +42,15 @@ namespace nd4j {
|
|||
return _INSTANCE;
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
||||
}
|
||||
|
||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
||||
return tadForDimensions(tadDescriptor);
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
|||
const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
|
||||
const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();
|
||||
|
||||
// fill input gradient arrays in accordance to type of loss function
|
||||
// fill input gradient arrays in accordance to kind of loss function
|
||||
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
|
||||
|
||||
// back prop pass
|
||||
|
|
|
@ -987,9 +987,10 @@ namespace shape {
|
|||
// dimsToExclude - should be sorted in increasing order
|
||||
ND4J_EXPORT _CUDA_HD int outerArrayIndexes(Nd4jLong* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr);
|
||||
|
||||
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
|
||||
// dimsToExclude - should be sorted in increasing order
|
||||
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
|
||||
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
|
||||
ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude = nullptr);
|
||||
|
||||
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
|
@ -28,46 +29,35 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
DECLARE_TYPES(adjust_hue) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, -2, -2) {
|
||||
CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 3 || input->rankOf() == 4, 0, "AdjustHue: op expects either 3D or 4D input, but got %i instead", input->rankOf());
|
||||
const int rank = input->rankOf();
|
||||
const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||
const double delta = T_ARG(0);
|
||||
|
||||
REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank);
|
||||
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
|
||||
REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);
|
||||
|
||||
double delta = 0;
|
||||
if (block.numT() > 0)
|
||||
delta = T_ARG(0);
|
||||
else if (block.width() > 1) {
|
||||
auto _d = INPUT_VARIABLE(1);
|
||||
if (!_d->isScalar()) {
|
||||
auto str = ShapeUtils::shapeAsString(_d);
|
||||
REQUIRE_TRUE(_d->isScalar(), 0, "AdjustHue: delta should be scalar NDArray, but got %s instead", str.c_str());
|
||||
}
|
||||
delta = _d->e<double>(0);
|
||||
}
|
||||
NDArray deltaScalarArr = NDArrayFactory::create<double>(delta, block.launchContext());
|
||||
|
||||
|
||||
bool isNHWC = false;
|
||||
if (block.numI() > 0)
|
||||
isNHWC = INT_ARG(0) == 1;
|
||||
|
||||
int numChannels = isNHWC ? input->sizeAt(-1) : input->sizeAt(-3);
|
||||
|
||||
REQUIRE_TRUE(numChannels == 3, 0, "AdjustHue: this operation expects image with 3 channels (R, G, B), but got % instead", numChannels);
|
||||
|
||||
auto ts = NDArrayFactory::create(delta, block.launchContext());
|
||||
// FIXME: delta should be NDArray scalar
|
||||
helpers::_adjust_hue(block.launchContext(), input, output, &ts, isNHWC);
|
||||
helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(adjust_hue) {
|
||||
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,45 +27,33 @@
|
|||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
DECLARE_TYPES(adjust_saturation) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, -2, -2) {
|
||||
CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 3 || input->rankOf() == 4, 0, "AdjustSaturation: op expects either 3D or 4D input, but got %i instead", input->rankOf());
|
||||
const int rank = input->rankOf();
|
||||
const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||
const double factor = T_ARG(0);
|
||||
|
||||
double delta = 0;
|
||||
if (block.numT() > 0)
|
||||
delta = T_ARG(0);
|
||||
else if (block.width() > 1) {
|
||||
auto _d = INPUT_VARIABLE(1);
|
||||
if (!_d->isScalar()) {
|
||||
auto str = ShapeUtils::shapeAsString(_d);
|
||||
REQUIRE_TRUE(_d->isScalar(), 0, "AdjustSaturation: delta should be scalar NDArray, but got %s instead", str.c_str());
|
||||
}
|
||||
REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank);
|
||||
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
|
||||
|
||||
delta = _d->e<double>(0);
|
||||
}
|
||||
NDArray factorScalarArr = NDArrayFactory::create<double>(factor, block.launchContext());
|
||||
|
||||
bool isNHWC = false;
|
||||
if (block.numI() > 0)
|
||||
isNHWC = INT_ARG(0) == 1;
|
||||
|
||||
int numChannels = isNHWC ? input->sizeAt(-1) : input->sizeAt(-3);
|
||||
|
||||
REQUIRE_TRUE(numChannels == 3, 0, "AdjustSaturation: this operation expects image with 3 channels (R, G, B), but got % instead", numChannels);
|
||||
|
||||
auto ts = NDArrayFactory::create(delta, block.launchContext());
|
||||
// FIXME: delta should be NDArray scalar
|
||||
helpers::adjust_saturation(block.launchContext(), input, output, &ts, isNHWC);
|
||||
helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(adjust_saturation) {
|
||||
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
OP_IMPL(scatter_add, 3, 1, true) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto indices = INPUT_VARIABLE(1);
|
||||
|
@ -74,8 +75,8 @@ namespace nd4j {
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_SYN(ScatterAdd, scatter_add);
|
||||
}
|
||||
|
||||
DECLARE_TYPES(scatter_add) {
|
||||
getOpDescriptor()
|
||||
|
@ -84,6 +85,8 @@ namespace nd4j {
|
|||
->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -57,16 +57,26 @@ namespace nd4j {
|
|||
auto in = inputShape->at(0);
|
||||
int outRank = shape::rank(in) + 1;
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto dtype = DataType::BOOL;
|
||||
Nd4jLong maxInd = input->argMax();
|
||||
float max = input->e<float>(maxInd);
|
||||
Nd4jLong max = input->e<Nd4jLong>(maxInd);
|
||||
|
||||
if (block.getIArguments()->size() > 0) {
|
||||
if (block.width() < 2) {
|
||||
maxInd = INT_ARG(0);
|
||||
if (maxInd < max)
|
||||
maxInd = static_cast<Nd4jLong>(max);
|
||||
if (block.getIArguments()->size() > 1)
|
||||
dtype = (DataType)INT_ARG(1);
|
||||
}
|
||||
else if (block.width() > 1) {
|
||||
else {
|
||||
dtype = (DataType)INT_ARG(0);
|
||||
}
|
||||
}
|
||||
|
||||
if (block.width() > 1) {
|
||||
auto maxlen = INPUT_VARIABLE(1);
|
||||
float tmaxlen = maxlen->e<float>(0);
|
||||
Nd4jLong tmaxlen = maxlen->e<Nd4jLong>(0);
|
||||
if (tmaxlen > max)
|
||||
maxInd = static_cast<Nd4jLong>(tmaxlen);
|
||||
}
|
||||
|
@ -80,14 +90,14 @@ namespace nd4j {
|
|||
outShapeInfo[i + 1] = shape::sizeAt(in, i);
|
||||
outShapeInfo[outRank] = lastDimension;
|
||||
|
||||
ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in));
|
||||
ShapeUtils::updateStridesAndType(outShapeInfo, dtype, shape::order(in));
|
||||
|
||||
return SHAPELIST(CONSTANT(outShapeInfo));
|
||||
}
|
||||
|
||||
DECLARE_TYPES(sequence_mask) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setAllowedInputTypes({ALL_INTS})
|
||||
->setAllowedOutputTypes(nd4j::DataType::ANY);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,11 +33,11 @@ OP_IMPL(random_shuffle, 1, 1, true) {
|
|||
const bool isInplace = block.isInplace();
|
||||
auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0);
|
||||
|
||||
nd4j::random::RandomBuffer* rng = block.getRNG();
|
||||
// nd4j::random::RandomBuffer* rng = block.getRNG();
|
||||
nd4j::graph::RandomGenerator rng = block.randomGenerator();
|
||||
// REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !");
|
||||
|
||||
REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !");
|
||||
|
||||
helpers::randomShuffle(block.launchContext(), *input, *output, *rng, isInplace);
|
||||
helpers::randomShuffle(block.launchContext(), *input, *output, rng, isInplace);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -31,6 +31,7 @@ namespace ops {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size
|
||||
auto hLast = INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
|
||||
auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights)
|
||||
|
@ -118,65 +119,58 @@ DECLARE_SHAPE_FN(gruCell) {
|
|||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(gruCell_bp, 6, 5, false, 0, 0) {
|
||||
CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0); // input [bS x iS]
|
||||
auto hi = INPUT_VARIABLE(1); // previous cell output [bS x nU]
|
||||
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [iS x 3*nU]
|
||||
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nU x 3*nU]
|
||||
auto b = INPUT_VARIABLE(4); // biases, [3*nU]
|
||||
auto dLdh = INPUT_VARIABLE(5); // gradient wrt output, [bS,nU], that is epsilon_next
|
||||
auto dLdWxi = block.width() > 6 ? INPUT_VARIABLE(6) : nullptr; // gradient wrt Wx at previous time step, [iS, 3*nU]
|
||||
auto dLdWhi = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // gradient wrt Wh at previous time step, [nU, 3*nU]
|
||||
auto dLdbi = block.width() > 8 ? INPUT_VARIABLE(8) : nullptr; // gradient wrt b at previous time step, [3*nU]
|
||||
auto W = INPUT_VARIABLE(2); // weights, [iS+nU x 2*nU]
|
||||
auto Wc = INPUT_VARIABLE(3); // c weights, [iS+nU x nU]
|
||||
auto b = INPUT_VARIABLE(4); // biases, [2*nU]
|
||||
auto bc = INPUT_VARIABLE(5); // biases, [nU]
|
||||
auto dLdr = INPUT_VARIABLE(6); // gradient wrt reset gate, [bS, nU]
|
||||
auto dLdu = INPUT_VARIABLE(7); // gradient wrt update gate, [bS, nU]
|
||||
auto dLdc = INPUT_VARIABLE(8); // gradient wrt cell state, [bS, nU]
|
||||
auto dLdh = INPUT_VARIABLE(9); // gradient wrt current cell output, [bS, nU]
|
||||
|
||||
auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS], that is epsilon
|
||||
auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS]
|
||||
auto dLdhi = OUTPUT_VARIABLE(1); // gradient wrt hi, [bS, nU]
|
||||
auto dLdWx = OUTPUT_VARIABLE(2); // gradient wrt Wx, [iS, 3*nU]
|
||||
auto dLdWh = OUTPUT_VARIABLE(3); // gradient wrt Wh, [nU, 3*nU]
|
||||
auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [3*nU]
|
||||
auto dLdW = OUTPUT_VARIABLE(2); // gradient wrt W, [iS+nU x 2*nU]
|
||||
auto dLdWc = OUTPUT_VARIABLE(3); // gradient wrt Wc, [iS+nU x nU]
|
||||
auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [2*nU]
|
||||
auto dLdbc = OUTPUT_VARIABLE(5); // gradient wrt c biases, [nU]
|
||||
|
||||
const int rank = x->rankOf(); // = 2
|
||||
const Nd4jLong bS = x->sizeAt(0);
|
||||
const Nd4jLong iS = x->sizeAt(1);
|
||||
const Nd4jLong nU = hi->sizeAt(1);
|
||||
|
||||
REQUIRE_TRUE(x->rankOf() == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", x->rankOf());
|
||||
|
||||
const std::string hiShape = ShapeUtils::shapeAsString(hi);
|
||||
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
||||
const std::string wxShape = ShapeUtils::shapeAsString(Wx);
|
||||
const std::string wxCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
|
||||
const std::string whShape = ShapeUtils::shapeAsString(Wh);
|
||||
const std::string whCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
|
||||
const std::string wShape = ShapeUtils::shapeAsString(W);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
|
||||
const std::string wcShape = ShapeUtils::shapeAsString(Wc);
|
||||
const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*nU});
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
|
||||
const std::string bcShape = ShapeUtils::shapeAsString(bc);
|
||||
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
|
||||
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdr);
|
||||
const std::string dLduShape = ShapeUtils::shapeAsString(dLdu);
|
||||
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdc);
|
||||
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdh);
|
||||
const std::string dLdhCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
||||
|
||||
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
|
||||
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU_CELL_BP op: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
|
||||
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU_CELL_BP op: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
|
||||
REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(dLdhShape == dLdhCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (epsilon_next), expected is %s, but got %s instead !", dLdhCorrectShape.c_str(), dLdhShape.c_str());
|
||||
REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
|
||||
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
|
||||
REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
|
||||
REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
|
||||
REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
|
||||
|
||||
if(dLdWxi != nullptr) {
|
||||
const std::string dLdWxiShape = ShapeUtils::shapeAsString(dLdWxi);
|
||||
const std::string dLdWxiCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
|
||||
REQUIRE_TRUE(dLdWxiShape == dLdWxiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWxi array (gradient wrt Wx at previous time step), expected is %s, but got %s instead !", dLdWxiCorrectShape.c_str(), dLdWxiShape.c_str());
|
||||
}
|
||||
|
||||
if(dLdWhi != nullptr) {
|
||||
const std::string dLdWhiShape = ShapeUtils::shapeAsString(dLdWhi);
|
||||
const std::string dLdWhiCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
|
||||
REQUIRE_TRUE(dLdWhiShape == dLdWhiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWhi array (gradient wrt Wh at previous time step), expected is %s, but got %s instead !", dLdWhiCorrectShape.c_str(), dLdWhiShape.c_str());
|
||||
}
|
||||
|
||||
if(dLdbi != nullptr) {
|
||||
const std::string dLdbiShape = ShapeUtils::shapeAsString(dLdbi);
|
||||
const std::string dLdbiCorrectShape = ShapeUtils::shapeAsString({3*nU});
|
||||
REQUIRE_TRUE(dLdbiShape == dLdbiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdbi array (gradient wrt biases at previous time step), expected is %s, but got %s instead !", dLdbiCorrectShape.c_str(), dLdbiShape.c_str());
|
||||
}
|
||||
|
||||
helpers::gruCellBP(block.launchContext(), x, hi, Wx, Wh, b, dLdh, dLdWxi, dLdWhi, dLdbi, dLdx, dLdhi, dLdWx, dLdWh, dLdb);
|
||||
helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -192,6 +186,7 @@ DECLARE_TYPES(gruCell_bp) {
|
|||
->setAllowedInputTypes(6, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(7, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(8, {ALL_FLOATS})
|
||||
->setAllowedInputTypes(9, {ALL_FLOATS})
|
||||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
|
||||
|
@ -199,53 +194,46 @@ DECLARE_SHAPE_FN(gruCell_bp) {
|
|||
|
||||
auto xShapeInfo = inputShape->at(0); // [bS x iS]
|
||||
auto hiShapeInfo = inputShape->at(1); // [bS x nU]
|
||||
auto wxShapeInfo = inputShape->at(2); // [iS x 3*nU]
|
||||
auto whShapeInfo = inputShape->at(3); // [nU x 3*nU]
|
||||
auto bShapeInfo = inputShape->at(4); // [3*nU]
|
||||
auto dLdhShapeInfo = inputShape->at(5); // [bS x nU]
|
||||
auto wShapeInfo = inputShape->at(2); // [iS+nU x 2*nU]
|
||||
auto wcShapeInfo = inputShape->at(3); // [iS+nU x nU]
|
||||
auto bShapeInfo = inputShape->at(4); // [2*nU]
|
||||
auto bcShapeInfo = inputShape->at(5); // [nU]
|
||||
auto dLdrShapeInfo = inputShape->at(6); // [bS, nU]
|
||||
auto dLduShapeInfo = inputShape->at(7); // [bS, nU]
|
||||
auto dLdcShapeInfo = inputShape->at(8); // [bS, nU]
|
||||
auto dLdhShapeInfo = inputShape->at(9); // [bS, nU]
|
||||
|
||||
const int rank = xShapeInfo[0]; // = 2
|
||||
const Nd4jLong bS = xShapeInfo[1];
|
||||
const Nd4jLong iS = xShapeInfo[2];
|
||||
const Nd4jLong nU = hiShapeInfo[2];
|
||||
|
||||
REQUIRE_TRUE(xShapeInfo[0] == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", xShapeInfo[0]);
|
||||
|
||||
const std::string hiShape = ShapeUtils::shapeAsString(hiShapeInfo);
|
||||
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
||||
const std::string wxShape = ShapeUtils::shapeAsString(wxShapeInfo);
|
||||
const std::string wxCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
|
||||
const std::string whShape = ShapeUtils::shapeAsString(whShapeInfo);
|
||||
const std::string whCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
|
||||
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
|
||||
const std::string wcShape = ShapeUtils::shapeAsString(wcShapeInfo);
|
||||
const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*nU});
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
|
||||
const std::string bcShape = ShapeUtils::shapeAsString(bcShapeInfo);
|
||||
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
|
||||
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdrShapeInfo);
|
||||
const std::string dLduShape = ShapeUtils::shapeAsString(dLduShapeInfo);
|
||||
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdcShapeInfo);
|
||||
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdhShapeInfo);
|
||||
const std::string dLdhCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
||||
|
||||
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
|
||||
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU_CELL_BP op: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
|
||||
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU_CELL_BP op: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
|
||||
REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(dLdhShape == dLdhCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (epsilon_next), expected is %s, but got %s instead !", dLdhCorrectShape.c_str(), dLdhShape.c_str());
|
||||
|
||||
if(block.width() > 6) {
|
||||
Nd4jLong* dLdWxiShapeInfo = inputShape->at(6); // [iS x 3*nU]
|
||||
const std::string dLdWxiShape = ShapeUtils::shapeAsString(dLdWxiShapeInfo);
|
||||
const std::string dLdWxiCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
|
||||
REQUIRE_TRUE(dLdWxiShape == dLdWxiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWxi array (gradient wrt Wx at previous time step), expected is %s, but got %s instead !", dLdWxiCorrectShape.c_str(), dLdWxiShape.c_str());
|
||||
}
|
||||
|
||||
if(block.width() > 7) {
|
||||
Nd4jLong* dLdWhiShapeInfo = inputShape->at(7); // [nU x 3*nU]
|
||||
const std::string dLdWhiShape = ShapeUtils::shapeAsString(dLdWhiShapeInfo);
|
||||
const std::string dLdWhiCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
|
||||
REQUIRE_TRUE(dLdWhiShape == dLdWhiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWhi array (gradient wrt Wh at previous time step), expected is %s, but got %s instead !", dLdWhiCorrectShape.c_str(), dLdWhiShape.c_str());
|
||||
}
|
||||
|
||||
if(block.width() > 8) {
|
||||
Nd4jLong* dLdbiShapeInfo = inputShape->at(8); // [3*nU]
|
||||
const std::string dLdbiShape = ShapeUtils::shapeAsString(dLdbiShapeInfo);
|
||||
const std::string dLdbiCorrectShape = ShapeUtils::shapeAsString({3*nU});
|
||||
REQUIRE_TRUE(dLdbiShape == dLdbiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdbi array (gradient wrt biases at previous time step), expected is %s, but got %s instead !", dLdbiCorrectShape.c_str(), dLdbiShape.c_str());
|
||||
}
|
||||
REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
|
||||
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
|
||||
REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
|
||||
REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
|
||||
REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
|
||||
|
||||
Nd4jLong *dLdxShapeInfo = nullptr;
|
||||
COPY_SHAPE(xShapeInfo, dLdxShapeInfo);
|
||||
|
@ -253,17 +241,19 @@ DECLARE_SHAPE_FN(gruCell_bp) {
|
|||
Nd4jLong *dLdhiShapeInfo = nullptr;
|
||||
COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo);
|
||||
|
||||
Nd4jLong *dLdWxShapeInfo = nullptr;
|
||||
COPY_SHAPE(wxShapeInfo, dLdWxShapeInfo);
|
||||
Nd4jLong *dLdWShapeInfo = nullptr;
|
||||
COPY_SHAPE(wShapeInfo, dLdWShapeInfo);
|
||||
|
||||
Nd4jLong *dLdWhShapeInfo = nullptr;
|
||||
COPY_SHAPE(whShapeInfo, dLdWhShapeInfo);
|
||||
Nd4jLong *dLdWcShapeInfo = nullptr;
|
||||
COPY_SHAPE(wcShapeInfo, dLdWcShapeInfo);
|
||||
|
||||
Nd4jLong *dLdbShapeInfo = nullptr;
|
||||
COPY_SHAPE(bShapeInfo, dLdbShapeInfo);
|
||||
|
||||
return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo);
|
||||
Nd4jLong *dLdbcShapeInfo = nullptr;
|
||||
COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo);
|
||||
|
||||
return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -553,33 +553,31 @@ namespace nd4j {
|
|||
/**
|
||||
* This operation adjusts image hue by delta
|
||||
* Input arrays:
|
||||
* 0 - 1D or 3D input array, must have 3 channels.
|
||||
* 1 - optional scalar, delta value
|
||||
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
||||
*
|
||||
* T arguments:
|
||||
* 0 - optional delta value
|
||||
* 0 - delta value
|
||||
*
|
||||
* Int arguments:
|
||||
* 0 - optional argument, isNHWC. false by default.
|
||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_adjust_hue)
|
||||
DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, -2, -2);
|
||||
DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation adjusts image saturation by delta
|
||||
* Input arrays:
|
||||
* 0 - 1D or 3D input array, must have 3 channels.
|
||||
* 1 - optional scalar, delta value
|
||||
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
||||
*
|
||||
* T arguments:
|
||||
* 0 - optional delta value
|
||||
* 0 - saturation factor
|
||||
*
|
||||
* Int arguments:
|
||||
* 0 - optional argument, isNHWC. false by default.
|
||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_adjust_saturation)
|
||||
DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, -2, -2);
|
||||
DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
|
||||
#endif
|
||||
|
||||
|
||||
|
|
|
@ -259,8 +259,8 @@ namespace ops {
|
|||
* Input arrays:
|
||||
* 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features
|
||||
* 1: previous cell output [batchSize x numUnits], that is at previous time step t-1
|
||||
* 2: RU weights - [(nIn+nOut), 2*numUnits] - reset and update gates (input/recurrent weights)
|
||||
* 3: C weights - [(nIn+nOut), numUnits] - cell gate (input/recurrent weights)
|
||||
* 2: RU weights - [(inSize+numUnits), 2*numUnits] - reset and update gates (input/recurrent weights)
|
||||
* 3: C weights - [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights)
|
||||
* 4: reset and update biases, [2*numUnits] - reset and update gates
|
||||
* 5: cell biases, [numUnits]
|
||||
*
|
||||
|
@ -275,7 +275,7 @@ namespace ops {
|
|||
#endif
|
||||
|
||||
#if NOT_EXCLUDED(OP_gruCell)
|
||||
DECLARE_CUSTOM_OP(gruCell_bp, 6, 5, false, 0, 0);
|
||||
DECLARE_CUSTOM_OP(gruCell_bp, 10, 6, false, 0, 0);
|
||||
#endif
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
|
@ -24,6 +25,88 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC);
|
||||
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, T& s, T& v) {
|
||||
|
||||
// h values are in range [0, 360)
|
||||
// s and v values are in range [0, 1]
|
||||
|
||||
const T max = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b));
|
||||
const T min = nd4j::math::nd4j_min<T>(r, nd4j::math::nd4j_min<T>(g, b));
|
||||
const T c = max - min;
|
||||
|
||||
// calculate h
|
||||
if(c == 0) {
|
||||
h = 0;
|
||||
}
|
||||
else if(max == r) {
|
||||
h = 60.f * ((g - b) / c) + (g >= b ? 0 : 360);
|
||||
}
|
||||
else if(max == g) {
|
||||
h = 60.f * ((b - r) / c) + 120;
|
||||
}
|
||||
else { // max == b
|
||||
h = 60.f * ((r - g) / c) + 240;
|
||||
}
|
||||
|
||||
// calculate s
|
||||
s = max == (T)0 ? (T)0 : c / max;
|
||||
|
||||
// calculate v
|
||||
v = max / 255.f;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& g, T& b) {
|
||||
|
||||
const float sector = h / 60.f;
|
||||
const T c = v * s;
|
||||
|
||||
if(0.f <= sector && sector < 1.f) {
|
||||
r = v;
|
||||
g = v - c * (1 - sector);
|
||||
b = v - c;
|
||||
}
|
||||
else if(1.f <= sector && sector < 2.f) {
|
||||
r = v - c * (sector - 1);
|
||||
g = v;
|
||||
b = v - c;
|
||||
}
|
||||
else if(2.f <= sector && sector < 3.f) {
|
||||
r = v - c;
|
||||
g = v;
|
||||
b = v - c * (3 - sector);
|
||||
}
|
||||
else if(3.f <= sector && sector < 4.f) {
|
||||
r = v - c;
|
||||
g = v - c * (sector - 3);
|
||||
b = v;
|
||||
}
|
||||
else if(4.f <= sector && sector < 5.f) {
|
||||
r = v - c * (5 - sector);
|
||||
g = v - c;
|
||||
b = v;
|
||||
}
|
||||
else { // 5.f <= sector < 6.f
|
||||
r = v;
|
||||
g = v - c;
|
||||
b = v - c * (sector - 5);
|
||||
}
|
||||
|
||||
r *= 255;
|
||||
g *= 255;
|
||||
b *= 255;
|
||||
}
|
||||
|
||||
/*////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) {
|
||||
T v_mid;
|
||||
|
@ -83,6 +166,7 @@ namespace helpers {
|
|||
*h = h_category + (increase ? ratio : (1 - ratio));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) {
|
||||
int h_category = static_cast<int>(h);
|
||||
|
@ -128,7 +212,7 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
|
||||
void _adjust_hue(nd4j::LaunchContext * context, NDArray *input, NDArray *output, NDArray *delta, bool isNHWC);
|
||||
*/
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
|
@ -25,6 +26,10 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC);
|
||||
|
||||
/*
|
||||
template <typename T>
|
||||
static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) {
|
||||
T vv = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b));
|
||||
|
@ -109,8 +114,8 @@ namespace helpers {
|
|||
*g = gg + m;
|
||||
*b = bb + m;
|
||||
}
|
||||
*/
|
||||
|
||||
void adjust_saturation(nd4j::LaunchContext * context, NDArray *input, NDArray *output, NDArray *delta, bool isNHWC);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -16,16 +16,84 @@
|
|||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/adjust_hue.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
template <typename T>
|
||||
static void _adjust_hue_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
||||
static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {
|
||||
|
||||
const T delta = deltaScalarArr->e<T>(0);
|
||||
const int rank = input->rankOf();
|
||||
|
||||
const T* x = input->bufferAsT<T>();
|
||||
T* z = output->bufferAsT<T>();
|
||||
|
||||
if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong i = 0; i < input->lengthOf(); i += 3) {
|
||||
|
||||
T h, s, v;
|
||||
|
||||
rgbToHsv<T>(x[i], x[i+1], x[i+2], h, s, v);
|
||||
|
||||
h += delta * 360;
|
||||
if(h > 360)
|
||||
h -= 360;
|
||||
else if(h < 0)
|
||||
h += 360;
|
||||
|
||||
hsvToRgb<T>(h, s, v, z[i], z[i+1], z[i+2]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
|
||||
|
||||
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
|
||||
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for(Nd4jLong i = 0; i < numOfTads; ++i) {
|
||||
|
||||
const T* xTad = x + packX.platformOffsets()[i];
|
||||
T* zTad = z + packZ.platformOffsets()[i];
|
||||
|
||||
T h, s, v;
|
||||
|
||||
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
|
||||
|
||||
h += delta * 360;
|
||||
if(h > 360)
|
||||
h -= 360;
|
||||
else if(h < 0)
|
||||
h += 360;
|
||||
|
||||
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
 * Entry point of the hue adjustment helper: selects the adjustHue_ instantiation
 * that matches the data type of input and forwards the arguments to it.
 * The context argument is not referenced by this implementation.
 */
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {

    const auto inputType = input->dataType();

    BUILD_SINGLE_SELECTOR(inputType, adjustHue_, (input, deltaScalarArr, output, dimC), LIBND4J_TYPES);
}
|
||||
|
||||
/*
|
||||
template <typename T>
|
||||
static void adjust_hue_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
||||
// we're 100% sure it's 3
|
||||
const int numChannels = 3;
|
||||
int tuples = array->lengthOf() / numChannels;
|
||||
|
@ -93,7 +161,7 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
|
||||
void _adjust_hue(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
||||
void adjust_hue_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
||||
auto xType = array->dataType();
|
||||
|
||||
float d = delta->e<float>(0);
|
||||
|
@ -104,18 +172,20 @@ namespace helpers {
|
|||
// FIXME: template selector should be moved out of loop
|
||||
PRAGMA_OMP_PARALLEL_FOR
|
||||
for (int e = 0; e < tSize; e++) {
|
||||
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
||||
delete tadsIn;
|
||||
delete tadsOut;
|
||||
} else {
|
||||
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||
}
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void _adjust_hue_single, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void adjust_hue_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES);
|
||||
*/
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,15 +16,83 @@
|
|||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/adjust_saturation.h>
|
||||
#include <ops/declarable/helpers/adjust_hue.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {
|
||||
|
||||
const T factor = factorScalarArr->e<T>(0);
|
||||
const int rank = input->rankOf();
|
||||
|
||||
const T* x = input->bufferAsT<T>();
|
||||
T* z = output->bufferAsT<T>();
|
||||
|
||||
if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong i = 0; i < input->lengthOf(); i += 3) {
|
||||
|
||||
T h, s, v;
|
||||
|
||||
rgbToHsv<T>(x[i], x[i+1], x[i+2], h, s, v);
|
||||
|
||||
s *= factor;
|
||||
if(s > 1.f)
|
||||
s = 1.f;
|
||||
else if(s < 0.f)
|
||||
s = 0.f;
|
||||
|
||||
hsvToRgb<T>(h, s, v, z[i], z[i+1], z[i+2]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
|
||||
|
||||
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
|
||||
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for(Nd4jLong i = 0; i < numOfTads; ++i) {
|
||||
|
||||
const T* xTad = x + packX.platformOffsets()[i];
|
||||
T* zTad = z + packZ.platformOffsets()[i];
|
||||
|
||||
T h, s, v;
|
||||
|
||||
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
|
||||
|
||||
s *= factor;
|
||||
if(s > 1.f)
|
||||
s = 1.f;
|
||||
else if(s < 0.f)
|
||||
s = 0.f;
|
||||
|
||||
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
 * Entry point of the saturation adjustment helper: selects the adjustSaturation_
 * instantiation that matches the data type of input and forwards the arguments to it.
 * The context argument is not referenced by this implementation.
 */
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {

    const auto inputType = input->dataType();

    BUILD_SINGLE_SELECTOR(inputType, adjustSaturation_, (input, factorScalarArr, output, dimC), LIBND4J_TYPES);
}
|
||||
|
||||
/*
|
||||
template <typename T>
|
||||
static void adjust_saturation_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
||||
// we're 100% sure it's 3
|
||||
|
@ -108,6 +176,7 @@ namespace helpers {
|
|||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES);
|
||||
*/
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -59,14 +59,17 @@ namespace helpers {
|
|||
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
|
||||
|
||||
bool fit = true;
|
||||
|
||||
for( int i = 0; fit && (i < dims.size()); i++ ) {
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
|
||||
for( int i = 0; i < dims.size(); i++ ) {
|
||||
if (fit) {
|
||||
dims[i] = reduceShape->e<Nd4jLong>(i);
|
||||
for (int e = 0; fit && (e < input->rankOf()); ++e)
|
||||
for (int e = 0; e < input->rankOf(); ++e)
|
||||
if (fit)
|
||||
if (input->sizeAt(e) % dims[i]) {
|
||||
fit = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check dims to fit input
|
||||
REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank.");
|
||||
|
|
|
@ -35,82 +35,88 @@ namespace helpers {
|
|||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc,
|
||||
const NDArray* bru, const NDArray* bc,
|
||||
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc,
|
||||
const NDArray* b, const NDArray* bc,
|
||||
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
|
||||
|
||||
//Inputs:
|
||||
// x input [bS, nIn], nIn - input size
|
||||
// hLast previous cell output [bS, nUn], that is at previous time step t-1, nUn - number of units
|
||||
// Wru RU weights - [nIn+nUn, 2*nUn] - reset and update gates
|
||||
// Wc C weights - [nIn+nUn, nUn] - cell gate
|
||||
// bru r and u biases, [2*nUn] - reset and update gates
|
||||
// bc c biases, [nUn] - cell gate
|
||||
// x input [bS, iS], iS - input size
|
||||
// hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
|
||||
// W RU weights - [iS+nU, 2*nU] - reset and update gates
|
||||
// Wc C weights - [iS+nU, nU] - cell gate
|
||||
// b r and u biases, [2*nU] - reset and update gates
|
||||
// bc c biases, [nU] - cell gate
|
||||
|
||||
//Outputs:
|
||||
// r Reset gate output [bS, nUn]
|
||||
// u Update gate output [bS, nUn]
|
||||
// c Cell gate output [bS, nUn]
|
||||
// h current cell output [bS, nUn]
|
||||
// r Reset gate output [bS, nU]
|
||||
// u Update gate output [bS, nU]
|
||||
// c Cell gate output [bS, nU]
|
||||
// h current cell output [bS, nU]
|
||||
|
||||
/***************************************************************************************/
|
||||
/************************ THIS IS NOT OPTIMIZED CODE ***********************************/
|
||||
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
|
||||
|
||||
const int bS = x->sizeAt(0);
|
||||
const int nIn = x->sizeAt(1);
|
||||
const int nUn = hLast->sizeAt(1);
|
||||
const int iS = x->sizeAt(1);
|
||||
const int nU = hLast->sizeAt(1);
|
||||
|
||||
NDArray Wr = (*Wru)({0,nIn, 0,0}); // reset gates weights [nIn, 2*nUn]
|
||||
NDArray Wu = (*Wru)({nIn,nIn+nUn, 0,0}); // updates gates weights [nUn, 2*nUn]
|
||||
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
|
||||
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
|
||||
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||
|
||||
NDArray Wcr = (*Wc)({0,nIn, 0,0}); // reset cell weights [nIn, nUn]
|
||||
NDArray Wcu = (*Wc)({nIn,nIn+nUn, 0,0}); // updates cell weights [nUn, nUn]
|
||||
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
|
||||
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
|
||||
|
||||
// gates = sigmoid(x*Wr + hLast*Wu + br + bu)
|
||||
NDArray gates = mmul(*x, Wr) + mmul(*hLast, Wu) + *bru; // [bS, nIn] * [nIn, 2*nUn] + [bS, nUn] * [nUn, 2*nUn] + [2*nUn] = [bS, 2*nUn]
|
||||
gates.applyTransform(transform::Sigmoid);
|
||||
NDArray br = (*b)({0, nU}); // [nU]
|
||||
NDArray bu = (*b)({nU, 2*nU}); // [nU]
|
||||
|
||||
// × means matrix multiplication
|
||||
// * means element-wise product or so called Hadamard product
|
||||
|
||||
// reset gate
|
||||
r->assign(gates({0,0, 0,nUn})); // [bS, nUn]
|
||||
r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||
r->applyTransform(transform::Sigmoid);
|
||||
|
||||
// update gate
|
||||
u->assign(gates({0,0, nUn,2*nUn})); // [bS, nUn]
|
||||
u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||
u->applyTransform(transform::Sigmoid);
|
||||
|
||||
// cell gate c = activation(x*Wcr + (r◦hlast)*Wcu + bc)
|
||||
c->assign(mmul(*x, Wcr) + mmul(*r * *hLast, Wcu) + *bc); // [bS, nIn] * [nIn, nUn] + [bS, nUn] * [nUn, nUn] + [nUn] = [bS, nUn]
|
||||
// cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc)
|
||||
c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||
c->applyTransform(transform::Tanh);
|
||||
|
||||
NDArray temp = 1.f - *c * *c;
|
||||
|
||||
// cell output
|
||||
h->assign(*u * *hLast + (1.f - *u) * *c);
|
||||
|
||||
|
||||
|
||||
|
||||
/***************************************************************************************/
|
||||
/********************** THIS MORE OPTIMAZED CODE (except concat ) **********************/
|
||||
/*************** THIS IS MORE OPTIMIZED CODE (should think about concat) ***************/
|
||||
/***************************************************************************************/
|
||||
/*
|
||||
//Concat inputs: x + hLast : [bs, nIn + nUn]
|
||||
NDArray xhConcat(x->ordering(), {bS, nIn + nUn}, x->dataType(), context); // concat([bs, nIn], [bs, nUn]) -> [bs, nIn + nUn]
|
||||
//Concat inputs: x + hLast : [bs, iS + nU]
|
||||
NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU]
|
||||
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
|
||||
|
||||
//mmul for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u)
|
||||
auto m = mmul(xhConcat, *Wru) + *bru ; // [bs, nIn+nUn] * [nIn+nUn, 2*nUn] = [bs, 2*nUn]
|
||||
//mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u)
|
||||
auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU]
|
||||
// m += *bru;
|
||||
|
||||
sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz)
|
||||
m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz)
|
||||
|
||||
r->assign(m({0,0, 0, nUn}));
|
||||
u->assign(m({0,0, nUn, 2*nUn}));
|
||||
r->assign(m({0,0, 0, nU}));
|
||||
u->assign(m({0,0, nU, 2*nU}));
|
||||
|
||||
// hLast = hLast * r
|
||||
xhConcat({0,0, nIn, nIn+nUn}) *= *r;
|
||||
xhConcat({0,0, iS, iS+nU}) *= *r;
|
||||
|
||||
//c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c)
|
||||
//c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c)
|
||||
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
|
||||
*c += *bc;
|
||||
tanhInplace(*c);
|
||||
c->applyTransform(transform::Tanh);
|
||||
|
||||
//Output: h = (1-u).*c + u .* hPrev
|
||||
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
|
||||
|
@ -122,19 +128,19 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
|
|||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
|
||||
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
|
||||
|
||||
// x input [time, bS, iS]
|
||||
// h0 initial cell output (at time step = 0) [bS, nUn]
|
||||
// Wx input-to-hidden weights, [iS, 3*nUn]
|
||||
// Wh hidden-to-hidden weights, [nUn, 3*nUn]
|
||||
// b biases, [3*nUn]
|
||||
// hLast initial cell output (at time step = 0) [bS, nU]
|
||||
// Wx input-to-hidden weights, [iS, 3*nU]
|
||||
// Wh hidden-to-hidden weights, [nU, 3*nU]
|
||||
// b biases, [3*nU]
|
||||
|
||||
// h is cell outputs at each time step [time, bS, nUn]
|
||||
// h is cell outputs at each time step [time, bS, nU]
|
||||
|
||||
const int time = x->sizeAt(0);
|
||||
|
||||
NDArray ht_1(*h0);
|
||||
NDArray ht_1(*hLast);
|
||||
|
||||
// loop through time steps
|
||||
for (int t = 0; t < time; ++t) {
|
||||
|
@ -148,105 +154,208 @@ void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
|
|||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
|
||||
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
|
||||
void gruCellBP(nd4j::LaunchContext* context,
|
||||
const NDArray* x, const NDArray* hLast,
|
||||
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
|
||||
const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
|
||||
NDArray* dLdx, NDArray* dLdhLast,
|
||||
NDArray* dLdW, NDArray* dLdWc,
|
||||
NDArray* dLdb, NDArray* dLdbc) {
|
||||
|
||||
//Inputs:
|
||||
// x input [bS, iS]
|
||||
// h0 previous cell output [bS, nUn], that is at previous time step t-1
|
||||
// Wx input-to-hidden weights, [iS, 3*nUn]
|
||||
// Wh hidden-to-hidden weights, [nUn, 3*nUn]
|
||||
// b biases, [3*nUn]
|
||||
// dLdh gradient wrt output, [bS,nUn], that is epsilon_next
|
||||
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nUn]
|
||||
// dLdWh0 gradient wrt Wh at previous time step, [nUn, 3*nUn]
|
||||
// dLdb0 gradient wrt b at previous time step, [3*nUn]
|
||||
// hLast previous cell output [bS, nU], that is at previous time step t-1
|
||||
// W weights - [iS+nU, 2*nU] - reset and update gates
|
||||
// Wc C weights - [iS+nU, nU] - cell gate
|
||||
// b r and u biases, [2*nU] - reset and update gates
|
||||
// bc c biases, [nU] - cell gate
|
||||
// dLdr gradient wrt reset gate, [bS, nU]
|
||||
// dLdu gradient wrt update gate, [bS, nU]
|
||||
// dLdc gradient wrt cell state, [bS, nU]
|
||||
// dLdh gradient wrt current cell output, [bS, nU]
|
||||
|
||||
// dLdx gradient wrt x, [bS, iS], that is epsilon
|
||||
// dLdh0 gradient wrt h0, [bS, nUn]
|
||||
// dLdWx gradient wrt Wx, [iS, 3*nUn]
|
||||
// dLdWh gradient wrt Wh, [nUn, 3*nUn]
|
||||
// dLdb gradient wrt b at previous time step, [3*nUn]
|
||||
//Outputs:
|
||||
// dLdx gradient wrt x, [bS, iS],
|
||||
// dLdhLast gradient wrt hLast, [bS, nU]
|
||||
// dLdW gradient wrt W, [iS+nU, 2*nU]
|
||||
// dLdWc gradient wrt Wc, [iS+nU, nU]
|
||||
// dLdb gradient wrt bru [2*nU]
|
||||
// dLdbc gradient wrt bc [nU]
|
||||
|
||||
// h is current cell output [bS, nUn], that is at current time step t
|
||||
// * means element-wise product or so called Hadamard product
|
||||
// × means matrix multiplication
|
||||
|
||||
/************************************************************************************************/
|
||||
/******************************* THIS IS NOT OPTIMIZED CODE *************************************/
|
||||
/*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/
|
||||
|
||||
const int bS = x->sizeAt(0);
|
||||
const int iS = x->sizeAt(1);
|
||||
const int nU = hLast->sizeAt(1);
|
||||
|
||||
NDArray xT = x->transpose(); // [iS, bS]
|
||||
NDArray hLastT = hLast->transpose(); // [nU, bS]
|
||||
|
||||
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
|
||||
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
|
||||
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||
|
||||
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
|
||||
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
|
||||
|
||||
NDArray br = (*b)({0, nU}); // [nU]
|
||||
NDArray bu = (*b)({nU, 2*nU}); // [nU]
|
||||
|
||||
NDArray WrxT = Wrx.transpose(); // [nU, iS]
|
||||
NDArray WuxT = Wux.transpose(); // [nU, iS]
|
||||
NDArray WrhT = Wrh.transpose(); // [nU, nU]
|
||||
NDArray WuhT = Wuh.transpose(); // [nU, nU]
|
||||
|
||||
NDArray WcxT = Wcx.transpose(); // [nU, iS]
|
||||
NDArray WchT = Wch.transpose(); // [nU, nU]
|
||||
|
||||
NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU]
|
||||
NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU]
|
||||
NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||
NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||
|
||||
NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU]
|
||||
NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU]
|
||||
|
||||
NDArray dLdbr = (*dLdb)({0, nU}); // [nU]
|
||||
NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU]
|
||||
|
||||
const int nUn = h0->sizeAt(1);
|
||||
|
||||
// ***** feed forward step ***** //
|
||||
// gates = sigmoid(x*Wx + h0*Wh + b)
|
||||
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nUn})) + mmul(*h0, (*Wh)({0,0, 0,2*nUn})) + (*b)({0,2*nUn})); // [bS, 2*nUn] + [bS, 2*nUn] + [1, 2*nUn] = [bS, 2*nUn]
|
||||
|
||||
// reset gate
|
||||
auto r = gates({0,0, 0, nUn}); // [bS, nUn]
|
||||
NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||
r.applyTransform(transform::Sigmoid);
|
||||
|
||||
// update gate
|
||||
auto u = gates({0,0, nUn, 2*nUn}); // [bS, nUn]
|
||||
// ◦ means element-wise product or so called Hadamard product
|
||||
// n = tanh(x*Wx + (r◦h0)*Wh + b)
|
||||
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nUn,3*nUn})) + mmul((*h0)*r, (*Wh)({0,0, 2*nUn,3*nUn})) + (*b)({2*nUn,3*nUn})); // [bS, nUn]
|
||||
NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||
u.applyTransform(transform::Sigmoid);
|
||||
|
||||
// cell gate c = activation(x×Wcx + (r*hLast)×Wch + bc)
|
||||
NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||
c.applyTransform(transform::Tanh);
|
||||
|
||||
// h = (1 - u) * c + u * hPrev
|
||||
|
||||
|
||||
// ***** back prop step ***** //
|
||||
auto Wxr = (*Wx)({0,0, 0, nUn});
|
||||
auto Wxu = (*Wx)({0,0, nUn, 2*nUn});
|
||||
auto Wxn = (*Wx)({0,0, 2*nUn,3*nUn});
|
||||
auto Whr = (*Wh)({0,0, 0, nUn});
|
||||
auto Whu = (*Wh)({0,0, nUn, 2*nUn});
|
||||
auto Whn = (*Wh)({0,0, 2*nUn,3*nUn});
|
||||
auto WxrT = Wxr.transpose();
|
||||
auto WxuT = Wxu.transpose();
|
||||
auto WxnT = Wxn.transpose();
|
||||
auto WhrT = Whr.transpose();
|
||||
auto WhuT = Whu.transpose();
|
||||
auto WhnT = Whn.transpose();
|
||||
auto xT = x->transpose();
|
||||
auto h0T = h0->transpose();
|
||||
|
||||
auto dLdWxr = (*dLdWx)({0,0, 0, nUn});
|
||||
auto dLdWxu = (*dLdWx)({0,0, nUn, 2*nUn});
|
||||
auto dLdWxn = (*dLdWx)({0,0, 2*nUn,3*nUn});
|
||||
// notations:
|
||||
// Zr = x × Wrx + hLast × Wrh + br
|
||||
// Zu = x × Wux + hLast × Wuh + bu
|
||||
// Sr = sigmoid(Zr)
|
||||
// Su = sigmoid(Zu)
|
||||
// Zc = x × Wcx + (r * hlast) × Wch + bc
|
||||
|
||||
auto dLdWhr = (*dLdWh)({0,0, 0, nUn});
|
||||
auto dLdWhu = (*dLdWh)({0,0, nUn, 2*nUn});
|
||||
auto dLdWhn = (*dLdWh)({0,0, 2*nUn,3*nUn});
|
||||
|
||||
auto dLdbr = (*dLdb)({0, nUn});
|
||||
auto dLdbu = (*dLdb)({nUn, 2*nUn});
|
||||
auto dLdbn = (*dLdb)({2*nUn,3*nUn});
|
||||
// dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
|
||||
// = dLdx_u + dLdx_c
|
||||
// dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
|
||||
// dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
|
||||
// dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
|
||||
// dZcdr = (... * hLast) × WchT
|
||||
// dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
|
||||
// drdx = drdZr * dZrdx
|
||||
// dZrdx = ... × WrxT
|
||||
// dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
|
||||
// finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
|
||||
|
||||
auto dhdu = *h0 - n; // [bS, nUn]
|
||||
auto dhdn = 1.f - u; // [bS, nUn]
|
||||
auto dSigdu = u * (1.f - u); // [bS, nUn]
|
||||
auto dSigdr = r * (1.f - r); // [bS, nUn]
|
||||
auto dActdn = 1.f - n * n; // [bS, nUn]
|
||||
auto dndr = mmul(dActdn * (*h0), WhnT);
|
||||
auto drdh0 = mmul(dSigdr, WhrT);
|
||||
|
||||
auto dLdn = (*dLdh) * dhdn;
|
||||
auto dLdu = (*dLdh) * dhdu;
|
||||
auto dLdr = dLdn * dndr;
|
||||
// dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
|
||||
// = dLdhLast_h + dLdhLast_u + dLdhLast_c
|
||||
// dLdhLast_h = dLdh * dhdhLast = dLdh * u
|
||||
// dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
|
||||
// dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
|
||||
// = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
|
||||
// = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
|
||||
// dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
|
||||
// dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
|
||||
// finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
|
||||
// = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
|
||||
|
||||
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
|
||||
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nUn]
|
||||
|
||||
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nUn]
|
||||
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nUn,nUn]
|
||||
// dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
|
||||
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
|
||||
// dZrdWrx = xT × ...
|
||||
// finally dLdWrx = xT × (dLdr * drdZr)
|
||||
|
||||
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nUn]
|
||||
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nUn,nUn]
|
||||
|
||||
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nUn]
|
||||
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nUn,nUn]
|
||||
// dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
|
||||
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
|
||||
// dZrdWrh = hLastT × ...
|
||||
// finally dLdWrh = hLastT × (dLdr * drdZr)
|
||||
|
||||
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nUn]
|
||||
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nUn]
|
||||
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nUn]
|
||||
|
||||
if(dLdWx0 != nullptr)
|
||||
*dLdWx += *dLdWx0;
|
||||
// dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
|
||||
// dZudWux = xT × ...
|
||||
// dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)
|
||||
|
||||
if(dLdWh0 != nullptr)
|
||||
*dLdWh += *dLdWh0;
|
||||
|
||||
if(dLdb0 != nullptr)
|
||||
*dLdb += *dLdb0;
|
||||
// dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
|
||||
// dZudWuh = hLastT × ...
|
||||
// finally dLdWuh = hLastT × (dLdu * dudZu)
|
||||
|
||||
|
||||
// dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
|
||||
// dZcdWcx = xT × ...
|
||||
// finally dLdWcx = xT × (dLdc * dcdZc)
|
||||
|
||||
|
||||
// dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
|
||||
// dZcdWch = (r*hLast)^T × ...
|
||||
// finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
|
||||
|
||||
|
||||
// dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
|
||||
// = dLdr * drdZr * dZrdbr
|
||||
// dZrdbr = 1
|
||||
// finally dLdbr = dLdr * drdZr
|
||||
|
||||
|
||||
// dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
|
||||
// dZudbu = 1
|
||||
// finally dLdbu = dLdu * dudZu
|
||||
|
||||
|
||||
// dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
|
||||
// dZcdbc = 1
|
||||
// finally dLdbc = dLdc * dcdZc
|
||||
|
||||
NDArray dhdc = 1.f - u; // [bS, nU]
|
||||
NDArray dhdu = *hLast - c; // [bS, nU]
|
||||
NDArray dudZu = u * dhdc; // [bS, nU]
|
||||
NDArray drdZr = r * (1.f - r); // [bS, nU]
|
||||
NDArray dcdZc = 1.f - c * c; // [bS, nU]
|
||||
NDArray dLdZc = *dLdc * dcdZc; // [bS, nU]
|
||||
NDArray dLdZu = *dLdu * dudZu; // [bS, nU]
|
||||
NDArray dLdZr = *dLdr * drdZr; // [bS, nU]
|
||||
|
||||
// NDArray dLdc = *dLdh * dhdc; // [bS, nU]
|
||||
// NDArray dLdu = *dLdh * dhdu; // [bS, nU]
|
||||
// NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU]
|
||||
|
||||
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS]
|
||||
|
||||
dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU]
|
||||
|
||||
dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU]
|
||||
dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU]
|
||||
dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU]
|
||||
dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU]
|
||||
|
||||
dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU]
|
||||
dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU]
|
||||
|
||||
dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||
dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||
|
||||
dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||
}
|
||||
|
||||
// //////////////////////////////////////////////////////////////////////////
|
||||
|
@ -255,34 +364,34 @@ void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h
|
|||
// void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
|
||||
|
||||
// NDArray<T>* x = inArrs[0]; // input [time, bS, iS]
|
||||
// NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nUn], that is at previous time step t-1
|
||||
// NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nUn]
|
||||
// NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nUn, 3*nUn]
|
||||
// NDArray<T>* b = inArrs[4]; // biases, [3*nUn]
|
||||
// NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nUn], that is epsilon_next
|
||||
// NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nU], that is at previous time step t-1
|
||||
// NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nU]
|
||||
// NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nU, 3*nU]
|
||||
// NDArray<T>* b = inArrs[4]; // biases, [3*nU]
|
||||
// NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nU], that is epsilon_next
|
||||
|
||||
// NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon
|
||||
// NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nUn]
|
||||
// NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nUn]
|
||||
// NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nUn, 3*nUn]
|
||||
// NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nUn]
|
||||
// NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nU]
|
||||
// NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nU]
|
||||
// NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nU, 3*nU]
|
||||
// NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nU]
|
||||
|
||||
// const Nd4jLong time = x->sizeAt(0);
|
||||
// const Nd4jLong bS = x->sizeAt(1);
|
||||
// const Nd4jLong iS = x->sizeAt(2);
|
||||
// const Nd4jLong nUn = hi->sizeAt(1);
|
||||
// const Nd4jLong nU = hi->sizeAt(1);
|
||||
|
||||
// NDArray<T> h(hi->ordering(), {time, bS, nUn}); // feed forward output
|
||||
// NDArray<T> h(hi->ordering(), {time, bS, nU}); // feed forward output
|
||||
|
||||
// // first step, time = 0, feed forward
|
||||
// NDArray<T> x0 = (*x)({{0,1}, {}, {}});
|
||||
// NDArray<T> h0 = h({{0,1}, {}, {}});
|
||||
// helpers::gruCell<T>({&x0, hi, Wx, Wh, b}, &h0);
|
||||
// NDArray<T> hLast = h({{0,1}, {}, {}});
|
||||
// helpers::gruCell<T>({&x0, hi, Wx, Wh, b}, &hLast);
|
||||
|
||||
// // first step, time = 0, back prop
|
||||
// NDArray<T> dLdx0 = (*dLdx)({{0,1}, {}, {}});
|
||||
// NDArray<T> dLdh0 = (*dLdh)({{0,1}, {}, {}});
|
||||
// helpers::gruCellBP<T>({&x0, hi, Wx, Wh, b, &dLdh0, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb});
|
||||
// NDArray<T> dLdhLast = (*dLdh)({{0,1}, {}, {}});
|
||||
// helpers::gruCellBP<T>({&x0, hi, Wx, Wh, b, &dLdhLast, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb});
|
||||
|
||||
// // loop through the rest time steps
|
||||
// for (Nd4jLong t = time-1; t > 0; --t) {
|
||||
|
@ -310,4 +419,3 @@ void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
|
||||
#include <ops/declarable/helpers/image_suppression.h>
|
||||
//#include <blas/NDArray.h>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -28,9 +30,8 @@ namespace helpers {
|
|||
template <typename T>
|
||||
static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
|
||||
std::vector<Nd4jLong> indices(scales->lengthOf());
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
|
||||
for (size_t i = 0; i < indices.size(); ++i)
|
||||
indices[i] = i;
|
||||
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
||||
|
||||
// std::vector<int> selected(output->lengthOf());
|
||||
|
@ -62,13 +63,15 @@ namespace helpers {
|
|||
};
|
||||
// int numSelected = 0;
|
||||
int numBoxes = boxes->sizeAt(0);
|
||||
int numSelected = 0;
|
||||
|
||||
for (int i = 0, numSelected = 0; i < numBoxes && numSelected < output->lengthOf(); ++i) {
|
||||
bool shouldSelect = true;
|
||||
for (int i = 0; i < numBoxes; ++i) {
|
||||
bool shouldSelect = numSelected < output->lengthOf();
|
||||
PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(numSelected))
|
||||
for (int j = numSelected - 1; j >= 0; --j) {
|
||||
if (shouldSelect)
|
||||
if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold))) {
|
||||
shouldSelect = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (shouldSelect) {
|
||||
|
|
|
@ -24,20 +24,20 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
template <typename I, typename B>
|
||||
static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) {
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2)
|
||||
for (Nd4jLong i = 0; i < maxIndex; i++)
|
||||
for(Nd4jLong k = 0; k < input->lengthOf(); k++)
|
||||
if (i < input->e<int>(k))
|
||||
output->p<T>(k * maxIndex + i, T(1.0f));
|
||||
if (i < input->t<I>(k))
|
||||
output->t<B>(k * maxIndex + i) = B(true); //, T(1.0f));
|
||||
}
|
||||
|
||||
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), sequenceMask_, (input, output, maxIndex), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -27,6 +27,7 @@
|
|||
#include <helpers/TAD.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
#include <Loops.h>
|
||||
#include <graph/RandomGenerator.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -81,7 +82,7 @@ static void trace_(const NDArray& input, NDArray& output) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) {
|
||||
void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
|
||||
|
||||
// check edge cases first
|
||||
int temp;
|
||||
|
@ -95,16 +96,16 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
|||
|
||||
// apply Fisher-Yates shuffle
|
||||
if(isInplace) {
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||
for(int i = firstDim-1; i > 0; --i) {
|
||||
int r = rng.nextInt(0, i);
|
||||
int r = rng.relativeInt(i) % i;
|
||||
if(i == r)
|
||||
continue;
|
||||
T _e0 = input.e<T>(i);
|
||||
T _e1 = input.e<T>(r);
|
||||
T t0 = input.t<T>(i);
|
||||
T t1 = input.t<T>(r);
|
||||
//math::nd4j_swap<T>(input(i), input(r));
|
||||
input.p<T>(i, _e1);
|
||||
input.p<T>(r, _e0);
|
||||
input.t<T>(i) = t1;
|
||||
input.t<T>(r) = t0;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
@ -113,12 +114,12 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
|||
output.p<T>(Nd4jLong(0), input.e<T>(0));
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||
for(int i = firstDim-1; i > 0; --i) {
|
||||
int r = rng.nextInt(0, i);
|
||||
output.p(i, input.e<T>(indices[r]));
|
||||
int r = rng.relativeInt(i) % i;
|
||||
output.t<T>(i) = input.t<T>(indices[r]);
|
||||
if(i == r)
|
||||
continue;
|
||||
|
||||
output.p(r, input.e<T>(indices[i]));
|
||||
output.t<T>(r) = input.t<T>(indices[i]);
|
||||
math::nd4j_swap<int>(indices[i], indices[r]);
|
||||
}
|
||||
rng.rewindH(firstDim-1);
|
||||
|
@ -132,9 +133,10 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
|||
|
||||
// apply Fisher-Yates shuffle
|
||||
if(isInplace) {
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
|
||||
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
|
||||
for(int i = firstDim - 1; i > 0; --i) {
|
||||
int r = rng.nextInt(0, i);
|
||||
int r = rng.relativeInt(i) % i;
|
||||
|
||||
if(i == r)
|
||||
continue;
|
||||
subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r));
|
||||
|
@ -146,9 +148,9 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
|||
std::vector<int> indices(firstDim);
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
bool isZeroShuffled = false;
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||
for(int i = firstDim - 1; i > 0; --i) {
|
||||
int r = rng.nextInt(0, i);
|
||||
int r = rng.relativeInt(i) % i;
|
||||
subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r]));
|
||||
if(r == 0)
|
||||
isZeroShuffled = true;
|
||||
|
@ -167,11 +169,11 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
|||
|
||||
}
|
||||
|
||||
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) {
|
||||
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -16,15 +16,92 @@
|
|||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/adjust_hue.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
#include <PointersManager.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
|
||||
void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets,
|
||||
const Nd4jLong numOfTads, const T delta, const int dimC) {
|
||||
|
||||
const T* x = reinterpret_cast<const T*>(vx);
|
||||
T* z = reinterpret_cast<T*>(vz);
|
||||
|
||||
__shared__ int rank;
|
||||
__shared__ Nd4jLong xDimCstride, zDimCstride;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
rank = shape::rank(xShapeInfo);
|
||||
xDimCstride = shape::stride(xShapeInfo)[dimC];
|
||||
zDimCstride = shape::stride(zShapeInfo)[dimC];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
|
||||
|
||||
const T* xTad = x + xTadOffsets[i];
|
||||
T* zTad = z + zTadOffsets[i];
|
||||
|
||||
T h, s, v;
|
||||
|
||||
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
|
||||
|
||||
h += delta * 360;
|
||||
if(h > 360)
|
||||
h -= 360;
|
||||
else if(h < 0)
|
||||
h += 360;
|
||||
|
||||
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
static _CUDA_H void adjustHueCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
|
||||
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
|
||||
const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC) {
|
||||
|
||||
adjustHueCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, deltaScalarArr->e<T>(0), dimC);
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void adjustHueCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC), LIBND4J_TYPES);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
/**
 * Host entry point of the CUDA adjust_hue helper.
 *
 * Builds TAD packs along dimC for input and output, sizes the launch so that there is roughly
 * one thread per TAD, and dispatches adjustHueCudaLauncher on the data type of the input.
 *
 * @param context         launch context providing the CUDA stream
 * @param input           source array
 * @param deltaScalarArr  array whose element 0 holds the hue shift
 * @param output          destination array
 * @param dimC            index of the channel dimension
 */
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {

    // TADs taken along the channel dimension only
    auto inputTads  = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
    auto outputTads = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});

    const Nd4jLong numOfTads = inputTads.numberOfTads();

    // ceil(numOfTads / threadsPerBlock) blocks
    const int threadsPerBlock = MAX_NUM_THREADS / 2;
    const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;

    PointersManager manager(context, "adjustHue");

    NDArray::prepareSpecialUse({output}, {input, deltaScalarArr});
    BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher,
                          (blocksPerGrid, threadsPerBlock, context->getCudaStream(),
                           input->getSpecialBuffer(), input->getSpecialShapeInfo(), inputTads.platformOffsets(),
                           output->specialBuffer(), output->specialShapeInfo(), outputTads.platformOffsets(),
                           numOfTads, deltaScalarArr, dimC),
                          LIBND4J_TYPES);
    NDArray::registerSpecialUse({output}, {input, deltaScalarArr});

    manager.synchronize();
}
|
||||
|
||||
|
||||
/*
|
||||
template <typename T>
|
||||
static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
|
||||
int numChannels = 3;
|
||||
|
@ -134,11 +211,13 @@ namespace helpers {
|
|||
|
||||
float d = delta->e<float>(0);
|
||||
if (array->rankOf() == 4) {
|
||||
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||
} else {
|
||||
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||
}
|
||||
}
|
||||
|
||||
*/
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,16 +16,93 @@
|
|||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/adjust_saturation.h>
|
||||
#include <ops/declarable/helpers/adjust_hue.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
#include <PointersManager.h>
|
||||
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void _CUDA_G adjustSaturationCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
|
||||
void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets,
|
||||
const Nd4jLong numOfTads, const T factor, const int dimC) {
|
||||
|
||||
const T* x = reinterpret_cast<const T*>(vx);
|
||||
T* z = reinterpret_cast<T*>(vz);
|
||||
|
||||
__shared__ int rank;
|
||||
__shared__ Nd4jLong xDimCstride, zDimCstride;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
rank = shape::rank(xShapeInfo);
|
||||
xDimCstride = shape::stride(xShapeInfo)[dimC];
|
||||
zDimCstride = shape::stride(zShapeInfo)[dimC];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
|
||||
|
||||
const T* xTad = x + xTadOffsets[i];
|
||||
T* zTad = z + zTadOffsets[i];
|
||||
|
||||
T h, s, v;
|
||||
|
||||
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
|
||||
|
||||
s *= factor;
|
||||
if(s > 1.f)
|
||||
s = 1.f;
|
||||
else if(s < 0.f)
|
||||
s = 0.f;
|
||||
|
||||
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
static _CUDA_H void adjustSaturationCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
|
||||
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
|
||||
const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC) {
|
||||
|
||||
adjustSaturationCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, factorScalarArr->e<T>(0), dimC);
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void adjustSaturationCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC), LIBND4J_TYPES);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
/**
 * Host entry point of the CUDA adjust_saturation helper.
 *
 * Builds TAD packs along dimC for input and output, sizes the launch so that there is roughly
 * one thread per TAD, and dispatches adjustSaturationCudaLauncher on the data type of the input.
 *
 * @param context          launch context providing the CUDA stream
 * @param input            source array
 * @param factorScalarArr  array whose element 0 holds the saturation factor
 * @param output           destination array
 * @param dimC             index of the channel dimension
 */
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {

    // TADs taken along the channel dimension only
    auto inputTads  = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
    auto outputTads = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});

    const Nd4jLong numOfTads = inputTads.numberOfTads();

    // ceil(numOfTads / threadsPerBlock) blocks
    const int threadsPerBlock = MAX_NUM_THREADS / 2;
    const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;

    PointersManager manager(context, "adjustSaturation");

    NDArray::prepareSpecialUse({output}, {input, factorScalarArr});
    BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher,
                          (blocksPerGrid, threadsPerBlock, context->getCudaStream(),
                           input->getSpecialBuffer(), input->getSpecialShapeInfo(), inputTads.platformOffsets(),
                           output->specialBuffer(), output->specialShapeInfo(), outputTads.platformOffsets(),
                           numOfTads, factorScalarArr, dimC),
                          LIBND4J_TYPES);
    NDArray::registerSpecialUse({output}, {input, factorScalarArr});

    manager.synchronize();
}
|
||||
|
||||
/*
|
||||
template <typename T>
|
||||
static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
|
||||
int numChannels = 3;
|
||||
|
@ -129,7 +206,7 @@ namespace helpers {
|
|||
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||
}
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,20 +22,99 @@
|
|||
#include <NativeOps.h>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <cuda_exception.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
static void dropoutSimple(NDArray const* input, NDArray* output, double probValue, int seed) {
|
||||
static __global__ void dropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong* outputShape, double probVal, int inLen, nd4j::graph::RandomGenerator* nodeRng) {
|
||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
__shared__ T const* input;
|
||||
__shared__ T* output;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
input = reinterpret_cast<T const*>(inputBuf);
|
||||
output = reinterpret_cast<T*>(outputBuf);
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
|
||||
|
||||
for (Nd4jLong e = 0; e < inLen; ++e) {
|
||||
T val = nodeRng->relativeT(e, T(0.f), T(1.f));
|
||||
|
||||
if (double(val) < probVal)
|
||||
output[shape::getIndexOffset(e, outputShape, inLen)] = T(input[shape::getIndexOffset(e, inputShape, inLen)] / probVal);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void dropoutSimple(nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed) {
|
||||
nd4j::graph::RandomGenerator nodeRng(3019L, seed);
|
||||
int inLen = input->lengthOf();
|
||||
nd4j::graph::RandomGenerator* dRandom;
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
||||
auto err = cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator));
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::dropoutSimple: Cannot allocate device memory for random generator.", err);
|
||||
}
|
||||
err = cudaMemcpy(dRandom, &nodeRng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice);
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::dropoutSimple: Cannot set up device memory for random generator.", err);
|
||||
}
|
||||
|
||||
dropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, inLen, dRandom);
|
||||
err = cudaFree(dRandom);
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::dropoutSimple: Cannot deallocate device memory for random generator.", err);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
|
||||
|
||||
template <typename T>
|
||||
int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||
|
||||
if (reduceShape == nullptr){
|
||||
dropoutSimple<T>(context.launchContext(), input, output, probValue, seed);
|
||||
}
|
||||
else {
|
||||
REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input");
|
||||
|
||||
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
|
||||
reduceShape->syncToHost(); // to ensure that follows are actual
|
||||
bool fit = true;
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
|
||||
for( int i = 0; i < dims.size(); i++ ) {
|
||||
if (fit) {
|
||||
dims[i] = reduceShape->e<Nd4jLong>(i);
|
||||
for (int e = 0; e < input->rankOf(); ++e)
|
||||
if (fit)
|
||||
if (input->sizeAt(e) % dims[i]) {
|
||||
fit = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check dims to fit input
|
||||
REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank.");
|
||||
std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
|
||||
chunk->assign(1.f);
|
||||
//chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue);
|
||||
//NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob);
|
||||
dropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), probValue, seed);
|
||||
// broadcast chunk to full matrix
|
||||
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
|
||||
dropOutMultiplier->assign(1.f);
|
||||
|
||||
*dropOutMultiplier += *chunk;
|
||||
|
||||
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -48,14 +127,121 @@ namespace helpers {
|
|||
// Instantiates _dropOutFunctor for the types listed in FLOAT_TYPES
// (presumably one explicit instantiation per type - confirm against the BUILD_SINGLE_TEMPLATE macro definition).
BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES);
|
||||
|
||||
/////////////////////////////////// backpropagation ///////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static __global__ void dropoutBPKernel(void* outputBuf, Nd4jLong* outputShape, void* gradOutBuf, Nd4jLong* gradOutShape, double probValue) {
|
||||
__shared__ T* output;
|
||||
__shared__ T* input;
|
||||
__shared__ int len;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
len = shape::length(outputShape);
|
||||
output = reinterpret_cast<T*>(outputBuf);
|
||||
input = reinterpret_cast<T*>(gradOutBuf);
|
||||
}
|
||||
|
||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (int e = tid; e < len; e += step) {
|
||||
if (output[shape::getIndexOffset(e, outputShape, len)] != T(0.))
|
||||
output[shape::getIndexOffset(e, outputShape, len)] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
|
||||
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||
return Status::OK();
|
||||
int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue);
|
||||
auto stream = context.launchContext()->getCudaStream();
|
||||
|
||||
if (ND4J_STATUS_OK == res)
|
||||
dropoutBPKernel<T><<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void alphaDropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong* outputShape, double probValue, double alpha, double alpha1, double beta, int inLen, nd4j::graph::RandomGenerator* nodeRng) {
|
||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
__shared__ T const* input;
|
||||
__shared__ T* output;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
input = reinterpret_cast<T const*>(inputBuf);
|
||||
output = reinterpret_cast<T*>(outputBuf);
|
||||
}
|
||||
|
||||
for (auto e = tid; e < inLen; e += step) {
|
||||
T val = nodeRng->relativeT(e, T(0.f), T(1.f));
|
||||
T xVal = input[shape::getIndexOffset(e, inputShape, inLen)];
|
||||
output[shape::getIndexOffset(e, outputShape, inLen)] = (val >= T(probValue) ? T(alpha * beta + alpha1) : T(alpha * (double)xVal + alpha1));
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
static void alphaDropoutSimple(nd4j::LaunchContext* context, NDArray const* input, NDArray* output, int seed, double probValue, double alpha, double alpha1, double beta) {
|
||||
nd4j::graph::RandomGenerator nodeRng(3019L, seed), *dRandom;
|
||||
auto stream = context->getCudaStream();
|
||||
auto err = cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator));
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot allocate device memory for random generator.", err);
|
||||
}
|
||||
err = cudaMemcpy(dRandom, &nodeRng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice);
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot set up device memory for random generator.", err);
|
||||
}
|
||||
|
||||
alphaDropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, alpha, alpha1, beta, output->lengthOf(), dRandom);
|
||||
|
||||
err = cudaFree(dRandom);
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot deallocate device memory for random generator.", err);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output,
|
||||
NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {
|
||||
|
||||
if (reduceShape == nullptr){
|
||||
alphaDropoutSimple<T>(context.launchContext(), input, output, seed, probValue, alpha, alpha1, beta);
|
||||
}
|
||||
else {
|
||||
REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input");
|
||||
|
||||
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
|
||||
reduceShape->syncToHost(); // to ensure that follows are actual
|
||||
bool fit = true;
|
||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
|
||||
for( int i = 0; i < dims.size(); i++ ) {
|
||||
if (fit) {
|
||||
dims[i] = reduceShape->e<Nd4jLong>(i);
|
||||
for (int e = 0; e < input->rankOf(); ++e)
|
||||
if (fit)
|
||||
if (input->sizeAt(e) % dims[i]) {
|
||||
fit = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check dims to fit input
|
||||
REQUIRE_TRUE(fit, 0, "alpha_dropout: Noise shape should fit to input rank.");
|
||||
std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
|
||||
chunk->assign(1.f);
|
||||
//chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue);
|
||||
//NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob);
|
||||
alphaDropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), seed, probValue, alpha, alpha1, beta);
|
||||
// broadcast chunk to full matrix
|
||||
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
|
||||
dropOutMultiplier->assign(1.f);
|
||||
|
||||
*dropOutMultiplier += *chunk;
|
||||
|
||||
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
|
||||
}
|
||||
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -63,7 +249,12 @@ namespace helpers {
|
|||
// Alpha-dropout backprop: re-runs the forward pass (alphaDropOutFunctor) into
// `output` and, when it succeeds, scales `output` by alpha and multiplies it
// element-wise by gradOut. Returns the status of the forward pass.
int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output,
                           NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {

    // an unconditional early return used to sit here and made everything below unreachable
    int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta);
    if (res == ND4J_STATUS_OK) {
        (*output) *= alpha;
        (*output) *= (*gradOut);
    }
    return res;
}
|
||||
|
||||
int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||
|
|
|
@ -35,58 +35,88 @@ namespace helpers {
|
|||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc,
|
||||
const NDArray* bru, const NDArray* bc,
|
||||
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc,
|
||||
const NDArray* b, const NDArray* bc,
|
||||
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
|
||||
|
||||
//Inputs:
|
||||
// x input [bS x inSize]
|
||||
// hLast previous cell output [bS x numUnits], that is at previous time step t-1
|
||||
// Wru RU weights - [bS, 2*numUnits] - reset and update gates
|
||||
// Wc C weights - [bS, numUnits] - cell gate
|
||||
// bru r and u biases, [2*numUnits] - reset and update gates
|
||||
// bc c biases, [numUnits] - cell gate
|
||||
// x input [bS, iS], iS - input size
|
||||
// hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
|
||||
// W RU weights - [iS+nU, 2*nU] - reset and update gates
|
||||
// Wc C weights - [iS+nU, nU] - cell gate
|
||||
// b r and u biases, [2*nU] - reset and update gates
|
||||
// bc c biases, [nU] - cell gate
|
||||
|
||||
//Outputs:
|
||||
// r Reset gate output [bS, numUnits]
|
||||
// u Update gate output [bS, numUnits]
|
||||
// c Cell gate output [bS, numUnits]
|
||||
// h current cell output [bS, numUnits]
|
||||
// r Reset gate output [bS, nU]
|
||||
// u Update gate output [bS, nU]
|
||||
// c Cell gate output [bS, nU]
|
||||
// h current cell output [bS, nU]
|
||||
|
||||
const int nIn = x->sizeAt(1);
|
||||
const int nU = hLast->sizeAt(1); // number of units
|
||||
/***************************************************************************************/
|
||||
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
|
||||
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
|
||||
|
||||
//Concat inputs: [x, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)]
|
||||
nd4j::ops::concat concatOp;
|
||||
std::vector<NDArray*> inputs;
|
||||
std::vector<double> targs;
|
||||
std::vector<Nd4jLong> iargs({1}); //Axis = 1
|
||||
std::vector<bool> bargs;
|
||||
inputs.emplace_back(const_cast<NDArray*>(x));
|
||||
inputs.emplace_back(const_cast<NDArray*>(hLast));
|
||||
const int bS = x->sizeAt(0);
|
||||
const int iS = x->sizeAt(1);
|
||||
const int nU = hLast->sizeAt(1);
|
||||
|
||||
auto result = concatOp.execute(inputs, targs, iargs, bargs);
|
||||
auto concatOut = result->at(0);
|
||||
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
|
||||
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
|
||||
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||
|
||||
//mmul/z for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u)
|
||||
auto m = mmul(*concatOut, *Wru); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 2*numUnits] = [bs, 4*numUnits]
|
||||
m += (*bru);
|
||||
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
|
||||
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
|
||||
|
||||
sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz)
|
||||
auto mr = m({0,0, 0, nU});
|
||||
auto mu = m({0,0, nU, 2*nU});
|
||||
NDArray br = (*b)({0, nU}); // [nU]
|
||||
NDArray bu = (*b)({nU, 2*nU}); // [nU]
|
||||
|
||||
r->assign(&mr);
|
||||
u->assign(&mu);
|
||||
// × means matrix multipication
|
||||
// * means element-wise product or so called Hadamard product
|
||||
|
||||
//Concatenated inputs: [x, yt-1 .* r]
|
||||
auto yr = (*concatOut)({0,0, nIn, nIn+nU});
|
||||
yr *= (*r);
|
||||
// reset gate
|
||||
r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||
r->applyTransform(transform::Sigmoid);
|
||||
|
||||
//c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c)
|
||||
MmulHelper::mmul(concatOut, const_cast<NDArray*>(Wc), c, 1.0, 0.0); //c = 1.0 * concatOut * Wc + 0.0 * c
|
||||
// update gate
|
||||
u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||
u->applyTransform(transform::Sigmoid);
|
||||
|
||||
// cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc)
|
||||
c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||
c->applyTransform(transform::Tanh);
|
||||
|
||||
NDArray temp = 1.f - *c * *c;
|
||||
|
||||
// cell output
|
||||
h->assign(*u * *hLast + (1.f - *u) * *c);
|
||||
|
||||
|
||||
/***************************************************************************************/
|
||||
/*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/
|
||||
/***************************************************************************************/
|
||||
/*
|
||||
//Concat inputs: x + hLast : [bs, iS + nU]
|
||||
NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU]
|
||||
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
|
||||
|
||||
//mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u)
|
||||
auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU]
|
||||
// m += *bru;
|
||||
|
||||
m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz)
|
||||
|
||||
r->assign(m({0,0, 0, nU}));
|
||||
u->assign(m({0,0, nU, 2*nU}));
|
||||
|
||||
// hLast = hLast * r
|
||||
xhConcat({0,0, iS, iS+nU}) *= *r;
|
||||
|
||||
//c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c)
|
||||
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
|
||||
*c += *bc;
|
||||
tanhInplace(*c);
|
||||
c->applyTransform(transform::Tanh);
|
||||
|
||||
//Output: h = (1-u).*c + u .* hPrev
|
||||
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
|
||||
|
@ -94,115 +124,238 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
|
|||
auto temp = (1.0f - *u);
|
||||
temp *= (*c);
|
||||
(*h) += temp;
|
||||
|
||||
delete result;
|
||||
*/
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {

    // Iterates over the time axis of x, taking the sub-array of x and of h for each step.
    // NOTE: the per-step gruCell call is still commented out below, so h is not written yet.

    // x        input [time, bS, iS]
    // hLast    initial cell output (at time step = 0) [bS, nU]
    // Wx       input-to-hidden  weights, [iS, 3*nU]
    // Wh       hidden-to-hidden weights, [nU, 3*nU]
    // b        biases, [3*nU]

    // h is cell outputs at each time step [time, bS, nU]

    const int time = x->sizeAt(0);

    NDArray ht_1(*hLast);

    // loop through time steps
    for (int t = 0; t < time; ++t) {

        auto xt = (*x)({t,t+1, 0,0, 0,0});
        auto ht = (*h)({t,t+1, 0,0, 0,0});

        // helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
        // ht_1.assign(ht);
    }
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void gruCellBP(nd4j::LaunchContext* context,
               const NDArray* x,    const NDArray* hLast,
               const NDArray* W,    const NDArray* Wc,        const NDArray* b,    const NDArray* bc,
               const NDArray* dLdr, const NDArray* dLdu,      const NDArray* dLdc, const NDArray* dLdh,
                     NDArray* dLdx,       NDArray* dLdhLast,
                     NDArray* dLdW,       NDArray* dLdWc,
                     NDArray* dLdb,       NDArray* dLdbc) {

    // Backprop through one GRU cell: recomputes the gates r, u, c from x and hLast,
    // then writes the gradients wrt x, hLast, W, Wc, b and bc.

    //Inputs:
    // x              input [bS, iS]
    // hLast          previous cell output [bS, nU],  that is at previous time step t-1
    // W              weights - [iS+nU, 2*nU] - reset and update gates
    // Wc             C weights - [iS+nU, nU] - cell gate
    // b              r and u biases, [2*nU] - reset and update gates
    // bc             c biases, [nU] - cell gate
    // dLdr           gradient wrt reset gate, [bS, nU]
    // dLdu           gradient wrt update gate, [bS, nU]
    // dLdc           gradient wrt cell state, [bS, nU]
    // dLdh           gradient wrt current cell output, [bS, nU]

    //Outputs:
    // dLdx           gradient wrt x,  [bS, iS],
    // dLdhLast       gradient wrt hLast, [bS, nU]
    // dLdW           gradient wrt W,  [iS+nU, 2*nU]
    // dLdWc          gradient wrt Wc, [iS+nU, nU]
    // dLdb           gradient wrt bru [2*nU]
    // dLdbc          gradient wrt bc  [nU]

    // * means element-wise product or so called Hadamard product
    // × means matrix multiplication

    /************************************************************************************************/
    /******************************* THIS IS NOT OPTIMIZED CODE *************************************/
    /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/

    const int bS = x->sizeAt(0);
    const int iS = x->sizeAt(1);
    const int nU = hLast->sizeAt(1);

    NDArray xT     = x->transpose();            // [iS, bS]
    NDArray hLastT = hLast->transpose();        // [nU, bS]

    NDArray Wrx = (*W)({0,iS,     0,nU});       // [iS, nU]
    NDArray Wux = (*W)({0,iS,     nU,2*nU});    // [iS, nU]
    NDArray Wrh = (*W)({iS,iS+nU, 0,nU});       // [nU, nU]
    NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU});    // [nU, nU]

    NDArray Wcx = (*Wc)({0,iS,     0,0});       // reset cell weights    [iS, nU]
    NDArray Wch = (*Wc)({iS,iS+nU, 0,0});       // updates cell weights  [nU, nU]

    NDArray br = (*b)({0,  nU});                // [nU]
    NDArray bu = (*b)({nU, 2*nU});              // [nU]

    NDArray WrxT = Wrx.transpose();             // [nU, iS]
    NDArray WuxT = Wux.transpose();             // [nU, iS]
    NDArray WrhT = Wrh.transpose();             // [nU, nU]
    NDArray WuhT = Wuh.transpose();             // [nU, nU]

    NDArray WcxT = Wcx.transpose();             // [nU, iS]
    NDArray WchT = Wch.transpose();             // [nU, nU]

    NDArray dLdWrx = (*dLdW)({0,iS,     0,nU});     // [iS, nU]
    NDArray dLdWux = (*dLdW)({0,iS,     nU,2*nU});  // [iS, nU]
    NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU});     // [nU, nU]
    NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU});  // [nU, nU]

    NDArray dLdWcx = (*dLdWc)({0,iS,     0,0});     // [iS, nU]
    NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0});     // [nU, nU]

    NDArray dLdbr = (*dLdb)({0,  nU});              // [nU]
    NDArray dLdbu = (*dLdb)({nU, 2*nU});            // [nU]


    // ***** feed forward step ***** //

    // reset gate
    NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
    r.applyTransform(transform::Sigmoid);

    // update gate
    NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
    u.applyTransform(transform::Sigmoid);

    // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc)
    NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc;    // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
    c.applyTransform(transform::Tanh);

    // h = (1 - u) * c + u * hPrev


    // ***** back prop step ***** //

    // notations:
    // Zr = x × Wrx + hLast × Wrh + br
    // Zu = x × Wux + hLast × Wuh + bu
    // Sr = sigmoid(Zr)
    // Su = sigmoid(Zu)
    // Zc = x × Wcx + (r * hlast) × Wch + bc


    // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
    //      = dLdx_u + dLdx_c
    // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
    // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
    // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
    // dZcdr = (... * hLast) × WchT
    // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
    // drdx = drdZr * dZrdx
    // dZrdx = ... × WrxT
    // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
    // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT


    // dLdhLast    = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
    //             = dLdhLast_h + dLdhLast_u + dLdhLast_c
    // dLdhLast_h  = dLdh * dhdhLast = dLdh * u
    // dLdhLast_u  = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
    // dLdhLast_c  = dLdc * dcdhLast  = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
    //             = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
    //             = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
    // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
    // dLdhLast_c1 = dLdr * drdhLast = |drdhLast  = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
    // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
    //                  = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT


    // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
    //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
    // dZrdWrx = xT × ...
    // finally dLdWrx = xT × (dLdr * drdZr)


    // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
    //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
    // dZrdWrh = hLastT × ...
    // finally dLdWrh = hLastT × (dLdr * drdZr)


    // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
    // dZudWux = xT × ...
    // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)


    // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
    // dZudWuh = hLastT × ...
    // finally dLdWuh = hLastT × (dLdu * dudZu)


    // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
    // dZcdWcx = xT × ...
    // finally dLdWcx = xT × (dLdc * dcdZc)


    // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
    // dZcdWch = (r*hLast)^T × ...
    // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)


    // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
    //       = dLdr * drdZr * dZrdbr
    // dZrdbr = 1
    // finally dLdbr = dLdr * drdZr


    // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
    // dZudbu = 1
    // finally dLdbu = dLdu * dudZu


    // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
    // dZcdbc = 1
    // finally dLdbc = dLdc * dcdZc

    NDArray dhdc  = 1.f - u;            // [bS, nU]
    NDArray dhdu  = *hLast - c;         // [bS, nU]  (only referenced by the commented-out alternative below)
    NDArray dudZu = u * dhdc;           // [bS, nU]
    NDArray drdZr = r * (1.f - r);      // [bS, nU]
    NDArray dcdZc = 1.f - c * c;        // [bS, nU]
    NDArray dLdZc = *dLdc * dcdZc;      // [bS, nU]
    NDArray dLdZu = *dLdu * dudZu;      // [bS, nU]
    NDArray dLdZr = *dLdr * drdZr;      // [bS, nU]

    // NDArray dLdc  = *dLdh * dhdc;                       // [bS, nU]
    // NDArray dLdu  = *dLdh * dhdu;                       // [bS, nU]
    // NDArray dLdr  = mmul(dLdc * dcdZc * *hLast, WchT);  // [bS, nU]

    dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT));                        // [bS, iS]

    dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT));    // [bS, nU]

    dLdWrx.assign(mmul(xT,     dLdZr));     // [iS, bS] × [bS, nU] = [iS, nU]
    dLdWrh.assign(mmul(hLastT, dLdZr));     // [nU, bS] × [bS, nU] = [nU, nU]
    dLdWux.assign(mmul(xT,     dLdZu));     // [iS, bS] × [bS, nU] = [iS, nU]
    dLdWuh.assign(mmul(hLastT, dLdZu));     // [nU, bS] × [bS, nU] = [nU, nU]

    dLdWcx.assign(mmul(xT, dLdZc));                         // [iS, bS] × [bS, nU] = [iS, nU]
    dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc));   // [nU, bS] × [bS, nU] = [nU, nU]

    dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0}));  // [nU]
    dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0}));  // [nU]

    dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU]
}
|
||||
|
||||
|
||||
|
|
|
@ -20,12 +20,111 @@
|
|||
|
||||
#include <ops/declarable/helpers/hashcode.h>
|
||||
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
template <typename T>
|
||||
static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong length) {
|
||||
|
||||
for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) {
|
||||
auto blockBuffer = buffer + b * numBlocks;
|
||||
|
||||
Nd4jLong r = 1;
|
||||
for (int e = threadIdx.x; e < blockSize && e + (b * numBlocks) < length; e += blockDim.x) {
|
||||
auto v = longBytes<T>(blockBuffer[e]);
|
||||
r = 31 * r + v;
|
||||
}
|
||||
|
||||
tempBuffer[b] = r;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong lastLength) {
|
||||
|
||||
for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) {
|
||||
auto blockBuffer = tempBuffer + b * numBlocks;
|
||||
|
||||
Nd4jLong r = 1;
|
||||
for (int e = threadIdx.x; e < blockSize && e + (b * numBlocks) < lastLength; e += blockDim.x) {
|
||||
auto v = longBytes<T>(blockBuffer[e]);
|
||||
r = 31 * r + v;
|
||||
}
|
||||
|
||||
tempResult[b] = r;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
// Copies the final hash into resultBuf: the first value of tempBufferA when the
// whole array fits into a single chunk (length <= blockSize), otherwise the
// first value of tempResult. Only thread 0 of a block performs the write.
static __global__ void lastStep(Nd4jLong* resultBuf, Nd4jLong* tempBufferA, Nd4jLong* tempResult, Nd4jLong length, Nd4jLong blockSize) {
    if (threadIdx.x != 0)
        return;

    const bool singleChunk = length <= blockSize;
    *resultBuf = singleChunk ? *tempBufferA : *tempResult;
}
|
||||
|
||||
template <typename T>
|
||||
void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) {
|
||||
auto blockSize = 32;
|
||||
auto stream = context->getCudaStream();
|
||||
array.syncToDevice();
|
||||
|
||||
NDArray::prepareSpecialUse({&result}, {&array});
|
||||
auto length = array.lengthOf();
|
||||
int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1);
|
||||
auto tempA = NDArrayFactory::create<Nd4jLong>('c', {numBlocks}, context);
|
||||
auto tempB = NDArrayFactory::create<Nd4jLong>('c', { numBlocks / blockSize + 1}, context);
|
||||
|
||||
auto buffer = reinterpret_cast<T*>(array.specialBuffer()); //bufferAsT<T>();
|
||||
auto tempBufferA = reinterpret_cast<Nd4jLong*>(tempA.specialBuffer()); //bufferAsT<Nd4jLong>();
|
||||
auto tempBufferB = reinterpret_cast<Nd4jLong*>(tempB.specialBuffer()); //bufferAsT<Nd4jLong>();
|
||||
|
||||
// default buffer is the first one, because it might be the last one in case of small arrays (< blockSize)
|
||||
auto tempBuffer = tempBufferA;
|
||||
auto tempResult = tempBufferB;
|
||||
|
||||
// we divide array into 32 element chunks, and store intermediate results once
|
||||
splitBufferToChuncks<T><<<numBlocks, length, 1024, *stream>>>(buffer, tempBuffer, numBlocks, blockSize, length);
|
||||
|
||||
// we replace pointer with intermediate one, and repeat only one chunk left
|
||||
int iterationCount = 0;
|
||||
while (numBlocks > 1) {
|
||||
int lastLength = numBlocks;
|
||||
numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1);
|
||||
|
||||
|
||||
internalHash<Nd4jLong><<<numBlocks, lastLength, 1024, *stream>>>(tempBuffer, tempResult, numBlocks, blockSize, lastLength);
|
||||
|
||||
|
||||
iterationCount++;
|
||||
// swapping buffers
|
||||
if (iterationCount % 2 == 0) {
|
||||
tempBuffer = tempBufferA;
|
||||
tempResult = tempBufferB;
|
||||
} else {
|
||||
tempBuffer = tempBufferB;
|
||||
tempResult = tempBufferA;
|
||||
}
|
||||
}
|
||||
|
||||
//lastStep<Nd4jLong><<<1,1,128, *stream>>>(result.specialBuffer(), tempBufferA, tempResult, length, blockSize);
|
||||
tempA.syncToHost();
|
||||
tempB.syncToHost();
|
||||
result.assign((length <= blockSize?tempA.e(0) : tempB.e(0)));
|
||||
|
||||
NDArray::registerSpecialUse({&result}, {&array});
|
||||
}
|
||||
|
||||
// Public entry point of the hash helper.
// Forwards (context, array, result) to the hashCode_ template through
// BUILD_SINGLE_SELECTOR, keyed on array.dataType() over LIBND4J_TYPES.
void hashCode(LaunchContext *context, NDArray &array, NDArray &result) {
|
||||
|
||||
BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void hashCode_, (LaunchContext* context, NDArray& array, NDArray& result), LIBND4J_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
|
||||
#include <ops/declarable/helpers/image_suppression.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <NativeOps.h>
|
||||
#include <cuda_exception.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -35,15 +37,16 @@ namespace helpers {
|
|||
Nd4jLong next1[] = {nextIndex, 1};
|
||||
Nd4jLong next2[] = {nextIndex, 2};
|
||||
Nd4jLong next3[] = {nextIndex, 3};
|
||||
|
||||
T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous2, 2)]);
|
||||
T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous3, 2)]);
|
||||
T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous2, 2)]);
|
||||
T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous3, 2)]);
|
||||
T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next2, 2)]);
|
||||
T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next3, 2)]);
|
||||
T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next2, 2)]);
|
||||
T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next3, 2)]);
|
||||
Nd4jLong* shapeOf = shape::shapeOf(boxesShape);
|
||||
Nd4jLong* strideOf = shape::stride(boxesShape);
|
||||
T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, previous0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous2, 2)]);
|
||||
T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, previous1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous3, 2)]);
|
||||
T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, previous0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous2, 2)]);
|
||||
T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, previous1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous3, 2)]);
|
||||
T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, next0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next2, 2)]);
|
||||
T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, next1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next3, 2)]);
|
||||
T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, next0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next2, 2)]);
|
||||
T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, next1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next3, 2)]);
|
||||
|
||||
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
|
||||
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
|
||||
|
@ -62,149 +65,101 @@ namespace helpers {
|
|||
};
|
||||
|
||||
template <typename T, typename I>
|
||||
static __global__ void nonMaxSuppressionKernel(T* boxes, Nd4jLong* boxesShape, I* indices, int* selectedIndices, Nd4jLong numBoxes, I* output, Nd4jLong* outputShape, T threshold) {
|
||||
__shared__ Nd4jLong outputLen;
|
||||
|
||||
static __global__ void shouldSelectKernel(T* boxesBuf, Nd4jLong* boxesShape, I* indexBuf, I* selectedIndicesData, double threshold, int numSelected, int i, bool* shouldSelect) {
|
||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = gridDim.x * blockDim.x;
|
||||
__shared__ bool shouldSelectShared;
|
||||
if (threadIdx.x == 0) {
|
||||
outputLen = shape::length(outputShape);
|
||||
shouldSelectShared = shouldSelect[0];
|
||||
}
|
||||
__syncthreads();
|
||||
for (int j = numSelected - 1 - tid; j >= 0; j -= step) {
|
||||
if (shouldSelectShared) {
|
||||
if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i],
|
||||
indexBuf[selectedIndicesData[j]], T(threshold)))
|
||||
shouldSelectShared = false;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
*shouldSelect = shouldSelectShared;
|
||||
}
|
||||
}
|
||||
template <typename I>
|
||||
|
||||
auto numSelected = blockIdx.x;
|
||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
static __global__ void copyIndices(void* indices, void* indicesLong, Nd4jLong len) {
|
||||
__shared__ I* indexBuf;
|
||||
__shared__ Nd4jLong* srcBuf;
|
||||
if (threadIdx.x == 0) {
|
||||
indexBuf = reinterpret_cast<I*>(indices);
|
||||
srcBuf = reinterpret_cast<Nd4jLong*>(indicesLong);
|
||||
}
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
// for (int numSelected = blockIdx.x; numSelected < outputLen; numSelected += gridDim.x) {
|
||||
for (int i = start; i < numBoxes; i += step) {
|
||||
bool shouldSelect = true;
|
||||
for (int j = numSelected - 1; shouldSelect && j >= 0; --j) {
|
||||
if (needToSuppressWithThreshold<T>(boxes, boxesShape, indices[i], indices[selectedIndices[j]], threshold)) {
|
||||
shouldSelect = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldSelect) {
|
||||
auto zPos = shape::getIndexOffset(numSelected, outputShape, outputLen);
|
||||
output[zPos] = indices[i];
|
||||
selectedIndices[numSelected] = i;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename I>
|
||||
static __global__ void sortIndices(I* indices, Nd4jLong* indexShape, T* scores, Nd4jLong* scoreShape) {
|
||||
__shared__ Nd4jLong len;
|
||||
// __shared__ Nd4jLong* sortedPart;
|
||||
// __shared__ Nd4jLong part;
|
||||
// __shared__ Nd4jLong partSize;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs; // ceil
|
||||
// part = blockIdx.x / blocksPerArr;
|
||||
|
||||
len = shape::length(indexShape);
|
||||
// __shared__ Nd4jLong* shmem = shared[];
|
||||
// sortedPart = shmem;
|
||||
}
|
||||
|
||||
for (int m = 0; m < len; m++) {
|
||||
if (m % 2 == 0) {
|
||||
for (int tid = threadIdx.x; tid < len; tid += blockDim.x) {
|
||||
auto top = 2 * tid + 1;
|
||||
if (top < len) {
|
||||
auto t0 = shape::getIndexOffset(top - 1, indexShape, len);
|
||||
auto t1 = shape::getIndexOffset(top, indexShape, len);
|
||||
auto z0 = shape::getIndexOffset(top - 1, scoreShape, len);
|
||||
auto z1 = shape::getIndexOffset(top, scoreShape, len);
|
||||
|
||||
if (scores[t0] < scores[t1]) {
|
||||
// swap indices first
|
||||
Nd4jLong di0 = indices[t0];
|
||||
indices[t0] = indices[t1];
|
||||
indices[t1] = di0;
|
||||
|
||||
//swap scores next
|
||||
// T dz0 = scores[z0];
|
||||
// scores[z0] = scores[z1];
|
||||
// scores[z1] = dz0;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int tid = threadIdx.x; tid < len; tid += blockDim.x) {
|
||||
auto top = 2 * tid + 2;
|
||||
if (top < len) {
|
||||
auto t0 = shape::getIndexOffset(top - 1, indexShape, len);
|
||||
auto t1 = shape::getIndexOffset(top, indexShape, len);
|
||||
auto z0 = shape::getIndexOffset(top - 1, scoreShape, len);
|
||||
auto z1 = shape::getIndexOffset(top, scoreShape, len);
|
||||
|
||||
if (scores[t0] < scores[t1]) {
|
||||
// swap indices first
|
||||
Nd4jLong di0 = indices[t0];
|
||||
indices[t0] = indices[t1];
|
||||
indices[t1] = di0;
|
||||
|
||||
//swap scores next
|
||||
// T dz0 = scores[z0];
|
||||
// scores[z0] = scores[z1];
|
||||
// scores[z1] = dz0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
for (auto i = tid; i < len; i += step)
|
||||
indexBuf[i] = (I)srcBuf[i];
|
||||
}
|
||||
|
||||
template <typename T, typename I>
|
||||
static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {boxes, scales});
|
||||
NDArray* indices = NDArrayFactory::create_<I>('c', {scales->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext());
|
||||
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext());
|
||||
indices->linspace(0);
|
||||
indices->syncToDevice(); // linspace only on CPU, so sync to Device as well
|
||||
|
||||
NDArray scores(*scales);
|
||||
indices->syncToHost(); //linspace(0);
|
||||
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
|
||||
T* scoreBuf = reinterpret_cast<T*>(scores.specialBuffer());
|
||||
sortIndices<T, I><<<1, 32, 128, *stream>>>(indexBuf, indices->specialShapeInfo(), scoreBuf, scores.specialShapeInfo());
|
||||
NativeOps nativeOps;
|
||||
|
||||
Nd4jPointer extras[2] = {nullptr, stream};
|
||||
|
||||
nativeOps.sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
|
||||
// TO DO: sort indices using scales as value row
|
||||
//std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
||||
indices->tickWriteDevice();
|
||||
indices->syncToHost();
|
||||
indices->printIndexedBuffer("AFTERSORT OUTPUT");
|
||||
NDArray selected = NDArrayFactory::create<int>({output->lengthOf()});
|
||||
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
|
||||
|
||||
NDArray selectedIndices = NDArrayFactory::create<int>({output->lengthOf()});
|
||||
NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()});
|
||||
int numSelected = 0;
|
||||
int numBoxes = boxes->sizeAt(0);
|
||||
T* boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer());
|
||||
// Nd4jLong* indicesData = reinterpret_cast<Nd4jLong*>(indices->specialBuffer());
|
||||
// int* selectedData = reinterpret_cast<int*>(selected.specialBuffer());
|
||||
int* selectedIndicesData = reinterpret_cast<int*>(selectedIndices.specialBuffer());
|
||||
|
||||
I* selectedIndicesData = reinterpret_cast<I*>(selectedIndices.specialBuffer());
|
||||
I* outputBuf = reinterpret_cast<I*>(output->specialBuffer());
|
||||
nonMaxSuppressionKernel<T, I><<<output->lengthOf(), 512, 1024, *stream>>>(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, numBoxes, outputBuf, output->specialShapeInfo(), T(threshold));
|
||||
NDArray::registerSpecialUse({output}, {boxes, scales});
|
||||
// for (int i = 0; i < boxes->sizeAt(0); ++i) {
|
||||
// if (selected.size() >= output->lengthOf()) break;
|
||||
// bool shouldSelect = true;
|
||||
// // Overlapping boxes are likely to have similar scores,
|
||||
// // therefore we iterate through the selected boxes backwards.
|
||||
// for (int j = numSelected - 1; j >= 0; --j) {
|
||||
// if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold)) {
|
||||
// shouldSelect = false;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// if (shouldSelect) {
|
||||
// selected.push_back(indices[i]);
|
||||
// selectedIndices[numSelected++] = i;
|
||||
// }
|
||||
// }
|
||||
// for (size_t e = 0; e < selected.size(); ++e)
|
||||
// output->p<int>(e, selected[e]);
|
||||
//
|
||||
delete indices;
|
||||
|
||||
bool* shouldSelectD;
|
||||
auto err = cudaMalloc(&shouldSelectD, sizeof(bool));
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot allocate memory for bool flag", err);
|
||||
}
|
||||
for (I i = 0; i < boxes->sizeAt(0); ++i) {
|
||||
bool shouldSelect = numSelected < output->lengthOf();
|
||||
if (shouldSelect) {
|
||||
err = cudaMemcpy(shouldSelectD, &shouldSelect, sizeof(bool), cudaMemcpyHostToDevice);
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", err);
|
||||
}
|
||||
|
||||
shouldSelectKernel<T> <<< 128, 256, 1024, *stream >>>
|
||||
(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD);
|
||||
err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), cudaMemcpyDeviceToHost);
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", err);
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldSelect) {
|
||||
cudaMemcpy(reinterpret_cast<I*>(output->specialBuffer()) + numSelected, indexBuf + i, sizeof(I), cudaMemcpyDeviceToDevice);
|
||||
cudaMemcpy(selectedIndicesData + numSelected, &i, sizeof(I), cudaMemcpyHostToDevice);
|
||||
numSelected++;
|
||||
}
|
||||
}
|
||||
|
||||
err = cudaFree(shouldSelectD);
|
||||
if (err) {
|
||||
throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot deallocate memory for bool flag", err);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
|
||||
|
|
|
@ -32,24 +32,24 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
||||
if (theFirst != theSecond) {
|
||||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
for (auto i = start; i < N; i += step) {
|
||||
Nd4jLong iCoord1[] = {theFirst, i};
|
||||
Nd4jLong iCoord2[] = {theSecond, i};
|
||||
auto iIndex1 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord1, 2);
|
||||
auto iIndex2 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord2, 2);
|
||||
//atomicExch(&matrix[iIndex1], matrix[iIndex2]);
|
||||
T e0 = matrix[iIndex1];
|
||||
T e1 = matrix[iIndex2];
|
||||
matrix[iIndex1] = e0;
|
||||
matrix[iIndex2] = e1;
|
||||
}
|
||||
}
|
||||
}
|
||||
// template <typename T>
|
||||
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
||||
// if (theFirst != theSecond) {
|
||||
// auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
// auto step = blockDim.x * gridDim.x;
|
||||
// for (auto i = start; i < N; i += step) {
|
||||
// Nd4jLong iCoord1[] = {theFirst, i};
|
||||
// Nd4jLong iCoord2[] = {theSecond, i};
|
||||
// auto iIndex1 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord1, 2);
|
||||
// auto iIndex2 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord2, 2);
|
||||
// //atomicExch(&matrix[iIndex1], matrix[iIndex2]);
|
||||
// T e0 = matrix[iIndex1];
|
||||
// T e1 = matrix[iIndex2];
|
||||
// matrix[iIndex1] = e0;
|
||||
// matrix[iIndex2] = e1;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
|
||||
//
|
||||
// void swapRows(NDArray* matrix, int theFirst, int theSecond) {
|
||||
|
@ -71,9 +71,14 @@ namespace helpers {
|
|||
|
||||
for (int i = start + 1; i < n; i += step) {
|
||||
Nd4jLong pos[] = {i, i - 1};
|
||||
Nd4jLong posX[] = {i, i};
|
||||
Nd4jLong posY[] = {i - 1, i - 1};
|
||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||
auto dxIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
||||
auto dyIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2);
|
||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
||||
inverted[zIndex] = -input[xIndex];
|
||||
inverted[zIndex] = -input[xIndex] / (input[dxIndex] * input[dyIndex]);
|
||||
// math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex] / input[dIndex]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,10 +96,11 @@ namespace helpers {
|
|||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (int i = start + 1; i < n; i += step) {
|
||||
for (int i = start; i < n; i += step) {
|
||||
Nd4jLong pos[] = {i, i};
|
||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
||||
// math::atomics::nd4j_atomicDiv(&inverted[zIndex], input[xIndex]);
|
||||
inverted[zIndex] /= input[xIndex];
|
||||
}
|
||||
}
|
||||
|
@ -113,16 +119,16 @@ namespace helpers {
|
|||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (int i = start + 1; i < n - 1; i += step) {
|
||||
for (int i = start; i < n - 1; i += step) {
|
||||
Nd4jLong pos[] = {i, i + 1};
|
||||
Nd4jLong posY[] = {i, i};
|
||||
//Nd4jLong posY[] = {i, i};
|
||||
Nd4jLong posX[] = {i + 1, i + 1};
|
||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||
auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2);
|
||||
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||
auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2);
|
||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
||||
inverted[zIndex] -= input[xIndex] * inverted[iIndex] / input[yIndex];
|
||||
math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex]); // / input[yIndex]);
|
||||
//inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i)
|
||||
}
|
||||
}
|
||||
|
@ -142,16 +148,18 @@ namespace helpers {
|
|||
// auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (int i = blockIdx.x + 2; i < n; i += gridDim.x) {
|
||||
for (int j = i - 2; j > -1; --j)
|
||||
for (int j = i - 2; j >= 0; --j)
|
||||
for (int k = threadIdx.x; k < i; k += blockDim.x) {
|
||||
Nd4jLong posZ[] = {i, j};
|
||||
Nd4jLong posX[] = {k, j};
|
||||
Nd4jLong posY[] = {i, k};
|
||||
Nd4jLong posY[] = {k, j};
|
||||
Nd4jLong posX[] = {i, k};
|
||||
Nd4jLong posD[] = {i, i};
|
||||
|
||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
||||
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
|
||||
auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
|
||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
|
||||
inverted[zIndex] -= inverted[yIndex] * input[xIndex];
|
||||
math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex] / input[dIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -176,13 +184,13 @@ namespace helpers {
|
|||
Nd4jLong posZ[] = {i, j};
|
||||
Nd4jLong posY[] = {k, j};
|
||||
Nd4jLong posX[] = {i, k};
|
||||
Nd4jLong posD[] = {i, i};
|
||||
// Nd4jLong posD[] = {i, i};
|
||||
|
||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
||||
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
|
||||
auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
|
||||
// auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
|
||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
|
||||
inverted[zIndex] -= inverted[yIndex] * input[xIndex] / input[dIndex];
|
||||
math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex]);// / input[dIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -196,14 +204,18 @@ namespace helpers {
|
|||
LaunchContext* context = inputMatrix->getContext();
|
||||
auto stream = context->getCudaStream();
|
||||
|
||||
// invert main diagonal
|
||||
upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
// invert the second diagonal
|
||||
invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE);
|
||||
|
||||
void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -215,58 +227,58 @@ namespace helpers {
|
|||
return;
|
||||
}
|
||||
|
||||
upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
//upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE);
|
||||
|
||||
void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void lupKernel(T* compound, Nd4jLong* compoundShape, T* permutation, Nd4jLong* permutationShape, Nd4jLong rowNum) {
|
||||
int swapCount = 0;
|
||||
for(int i = blockIdx.x; i < rowNum; i += gridDim.x ) {
|
||||
auto pivotValue = T(0.0);
|
||||
auto pivot = -1;
|
||||
|
||||
for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) {
|
||||
Nd4jLong rowCoord[] = {rowCounter, i};
|
||||
auto rowPos = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), rowCoord, 2);
|
||||
if(nd4j::math::nd4j_abs(compound[rowPos]) > pivotValue ) {
|
||||
pivotValue = nd4j::math::nd4j_abs(compound[rowPos]);
|
||||
pivot = rowCounter;
|
||||
}
|
||||
}
|
||||
|
||||
if( pivotValue != T(0.0) ) {
|
||||
swapRows_<T>(compound, compoundShape, pivot, i, rowNum);
|
||||
swapRows_<T>(permutation, permutationShape, pivot, i, rowNum);
|
||||
if (pivot != i)
|
||||
swapCount++;
|
||||
|
||||
for( int j = i + 1; j < rowNum; j++ ) {
|
||||
Nd4jLong posJIbuf[] = {j, i};
|
||||
Nd4jLong posIIbuf[] = {i, i};
|
||||
auto posJI = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJIbuf, 2);
|
||||
auto posII = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIIbuf, 2);
|
||||
|
||||
compound[posJI] /= compound[posII];
|
||||
for( int k = i + 1; k < rowNum; k++ ) {
|
||||
Nd4jLong posJKbuf[] = {j, k};
|
||||
Nd4jLong posIKbuf[] = {i, k};
|
||||
auto posJK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJKbuf, 2);
|
||||
auto posIK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIKbuf, 2);
|
||||
T arg = compound[posJI] * compound[posIK];
|
||||
compound[posJK] -= arg;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// template <typename T>
|
||||
// static __global__ void lupKernel(T* compound, Nd4jLong* compoundShape, T* permutation, Nd4jLong* permutationShape, Nd4jLong rowNum) {
|
||||
// int swapCount = 0;
|
||||
// for(int i = blockIdx.x; i < rowNum; i += gridDim.x ) {
|
||||
// auto pivotValue = T(0.0);
|
||||
// auto pivot = -1;
|
||||
//
|
||||
// for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) {
|
||||
// Nd4jLong rowCoord[] = {rowCounter, i};
|
||||
// auto rowPos = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), rowCoord, 2);
|
||||
// if(nd4j::math::nd4j_abs(compound[rowPos]) > pivotValue ) {
|
||||
// pivotValue = nd4j::math::nd4j_abs(compound[rowPos]);
|
||||
// pivot = rowCounter;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if( pivotValue != T(0.0) ) {
|
||||
// swapRows_<T>(compound, compoundShape, pivot, i, rowNum);
|
||||
// swapRows_<T>(permutation, permutationShape, pivot, i, rowNum);
|
||||
// if (pivot != i)
|
||||
// swapCount++;
|
||||
//
|
||||
// for( int j = i + 1; j < rowNum; j++ ) {
|
||||
// Nd4jLong posJIbuf[] = {j, i};
|
||||
// Nd4jLong posIIbuf[] = {i, i};
|
||||
// auto posJI = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJIbuf, 2);
|
||||
// auto posII = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIIbuf, 2);
|
||||
//
|
||||
// compound[posJI] /= compound[posII];
|
||||
// for( int k = i + 1; k < rowNum; k++ ) {
|
||||
// Nd4jLong posJKbuf[] = {j, k};
|
||||
// Nd4jLong posIKbuf[] = {i, k};
|
||||
// auto posJK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJKbuf, 2);
|
||||
// auto posIK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIKbuf, 2);
|
||||
// T arg = compound[posJI] * compound[posIK];
|
||||
// compound[posJK] -= arg;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
template <typename T, typename F>
|
||||
static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) {
|
||||
|
@ -332,6 +344,30 @@ namespace helpers {
|
|||
matrix[j] = (F)inputBuf[xIndex];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename F>
|
||||
static __global__ void returnMatrix(void* output, Nd4jLong* outputShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) {
|
||||
__shared__ F* matrix;
|
||||
__shared__ T* outputBuf;
|
||||
__shared__ Nd4jLong outputLen;
|
||||
__shared__ Nd4jLong n2;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
matrix = reinterpret_cast<F*>(input);
|
||||
outputBuf = reinterpret_cast<T*>(output);
|
||||
outputLen = shape::length(inputShape);
|
||||
n2 = rowLen * rowLen;
|
||||
}
|
||||
__syncthreads();
|
||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (int k = pos + start, j = start; j < n2; k += step, j += step) {
|
||||
auto zIndex = shape::getIndexOffset(k, outputShape, outputLen);
|
||||
outputBuf[zIndex] = (T)matrix[j];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) {
|
||||
__shared__ F* permutation;
|
||||
|
@ -462,7 +498,7 @@ namespace helpers {
|
|||
d_work,
|
||||
permutationBuf,
|
||||
d_info);
|
||||
fillUpPermutation<float><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
|
||||
fillUpPermutation<T><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
|
||||
permutation->tickWriteDevice();
|
||||
}
|
||||
err = cudaFree(d_work);
|
||||
|
@ -483,7 +519,7 @@ namespace helpers {
|
|||
// NDArray::registerSpecialUse({input}, {input});
|
||||
input->tickWriteDevice();
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_NATIVE);
|
||||
|
||||
template <typename T>
|
||||
static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
|
||||
|
@ -504,32 +540,32 @@ namespace helpers {
|
|||
output->assign(1.f);
|
||||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
Nd4jLong pos = e * n2;
|
||||
if (matrix.dataType() == input->dataType())
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
else
|
||||
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
// else
|
||||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
|
||||
if (matrix.dataType() == input->dataType())
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
lup_<T>(context, &matrix, nullptr, nullptr);
|
||||
else
|
||||
lup_<float>(context, &matrix, nullptr, nullptr);
|
||||
// else
|
||||
// lup_<float>(context, &matrix, nullptr, nullptr);
|
||||
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
|
||||
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
|
||||
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
|
||||
if (matrix.dataType() == input->dataType())
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||
else
|
||||
determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||
// else
|
||||
// determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);
|
||||
|
||||
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -552,22 +588,22 @@ namespace helpers {
|
|||
output->assign(1.f);
|
||||
for (int e = 0; e < output->lengthOf(); e++) {
|
||||
Nd4jLong pos = e * n2;
|
||||
if (matrix.dataType() == input->dataType())
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
else
|
||||
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
// else
|
||||
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||
|
||||
if (matrix.dataType() == input->dataType())
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
lup_<T>(context, &matrix, nullptr, nullptr);
|
||||
else
|
||||
lup_<float>(context, &matrix, nullptr, nullptr);
|
||||
// else
|
||||
// lup_<float>(context, &matrix, nullptr, nullptr);
|
||||
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
|
||||
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
|
||||
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
|
||||
if (matrix.dataType() == input->dataType())
|
||||
// if (matrix.dataType() == input->dataType())
|
||||
determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||
else
|
||||
determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||
// else
|
||||
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input});
|
||||
|
||||
|
@ -576,10 +612,10 @@ namespace helpers {
|
|||
return ND4J_STATUS_OK;
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);
|
||||
|
||||
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -597,10 +633,12 @@ namespace helpers {
|
|||
|
||||
if (threadIdx.x == 0) {
|
||||
xShapeOf = shape::shapeOf(lowerShape);
|
||||
yShapeOf = shape::shapeOf(upperShape);
|
||||
zShapeOf = shape::shapeOf(matrixShape);
|
||||
xStrideOf = shape::stride(lowerShape);
|
||||
|
||||
yShapeOf = shape::shapeOf(upperShape);
|
||||
yStrideOf = shape::stride(upperShape);
|
||||
|
||||
zShapeOf = shape::shapeOf(matrixShape);
|
||||
zStrideOf = shape::stride(matrixShape);
|
||||
lowerMatrix = reinterpret_cast<T*>(lowerBuf);
|
||||
upperMatrix = reinterpret_cast<T*>(upperBuf);
|
||||
|
@ -610,15 +648,16 @@ namespace helpers {
|
|||
|
||||
for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it
|
||||
for (int j = threadIdx.x; j < n; j += blockDim.x) {
|
||||
Nd4jLong posX[] = {j, k};
|
||||
|
||||
Nd4jLong posX[] = {k, j};
|
||||
Nd4jLong posD[] = {j, j};
|
||||
auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2);
|
||||
auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2);
|
||||
auto pos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2);
|
||||
if (k <= j)
|
||||
lowerMatrix[xPos] = matrix[pos];//(k, j);
|
||||
auto iPos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2);
|
||||
auto dPos = shape::getOffset(0, zShapeOf, zStrideOf, posD, 2);
|
||||
if (k >= j)
|
||||
lowerMatrix[xPos] = matrix[iPos];//(k, j);
|
||||
else
|
||||
upperMatrix[yPos] = matrix[pos]; //k, j);
|
||||
upperMatrix[yPos] = matrix[iPos]; //k, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -639,38 +678,26 @@ namespace helpers {
|
|||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1});
|
||||
auto stream = context->getCudaStream();
|
||||
|
||||
// PRAGMA_OMP_PARALLEL_FOR
|
||||
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
|
||||
fillMatrix<T, float><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
|
||||
permutation.assign(0.f);
|
||||
lup_<float>(context, &matrix, &compound, &permutation);
|
||||
fillMatrix<T, T><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
|
||||
matrix.tickWriteDevice();
|
||||
permutation.tickWriteDevice();
|
||||
permutation.printIndexedBuffer("PERMUTE");
|
||||
lower.setIdentity(); // set up U to identity matrix
|
||||
upper.setIdentity();
|
||||
fillLowerUpperKernel<float><<<1, n2, 128>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n);
|
||||
lower.tickWriteDevice();
|
||||
upper.tickWriteDevice();
|
||||
invertUpperMatrix(&upper, &matrix);
|
||||
invertLowerMatrix(&lower, &upper);
|
||||
lower.tickWriteDevice();
|
||||
upper.tickWriteDevice();
|
||||
lower.printIndexedBuffer("LOWER");
|
||||
upper.printIndexedBuffer("UPPER");
|
||||
compound.assign(matrix);
|
||||
lup_<T>(context, &compound, nullptr, nullptr);
|
||||
fillLowerUpperKernel<T><<<n, n, 128>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
|
||||
matrix.assign(0);
|
||||
invertUpperMatrix(&upper, &matrix); // U^{-1}
|
||||
compound.assign(0);
|
||||
invertLowerMatrix(&lower, &compound); // L{-1}
|
||||
|
||||
nd4j::MmulHelper::mmul(&matrix, &upper, &compound, 1.0, 0.0);
|
||||
nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0);
|
||||
// for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||
// output->t<T>(k) = matrix.template t<T>(row++);
|
||||
// }
|
||||
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
|
||||
returnMatrix<T, T><<<1, n2, 128, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
|
||||
}
|
||||
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
|
||||
}
|
||||
|
||||
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
|
||||
|
@ -803,7 +830,7 @@ namespace helpers {
|
|||
return cholesky_(context, input, output, inplace);
|
||||
}
|
||||
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);
|
||||
|
||||
__global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) {
|
||||
__shared__ double* output;
|
||||
|
|
|
@ -143,7 +143,7 @@ namespace helpers {
|
|||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void _reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){
|
||||
static void reverseSequence_(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){
|
||||
int posOfNonUnityDim = -1;
|
||||
seqLengths->syncToHost();
|
||||
auto stream = context->getCudaStream();
|
||||
|
@ -193,7 +193,7 @@ namespace helpers {
|
|||
}
|
||||
|
||||
void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), _reverseSequence, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -391,7 +391,7 @@ static void scatterCudaLauncher(const int blocksPerGrid, const int threadsPerBlo
|
|||
///////////////////////////////////////////////////////////////////
|
||||
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
|
||||
|
||||
PointersManager manager(context, "scatterND");
|
||||
PointersManager manager(context, "scatter");
|
||||
|
||||
NDArray::prepareSpecialUse({&output}, {&updates, &indices});
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,427 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/segment.h>
|
||||
#include <ops/declarable/helpers/segment_common.h>
|
||||
|
||||
#include <NDArrayFactory.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <helpers/TAD.h>
|
||||
#include <exceptions/cuda_exception.h>
|
||||
#include <PointersManager.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Segment ops linear kernels
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template<typename T, typename I>
|
||||
static __global__ void
|
||||
segmentMaxLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
|
||||
void *output, Nd4jLong *outputShape) {
|
||||
__shared__
|
||||
T *val;
|
||||
__shared__
|
||||
Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__
|
||||
T *x;
|
||||
__shared__
|
||||
T *z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x / threadsPerSegment;
|
||||
x = reinterpret_cast<T *>(input);
|
||||
z = reinterpret_cast<T *>(output);
|
||||
extern __shared__ unsigned char shmem[];
|
||||
val = reinterpret_cast<T *>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
|
||||
val[segment] = z[zIndex];
|
||||
}
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template<typename T, typename I>
|
||||
static __global__ void
|
||||
unsortedSegmentMaxLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
|
||||
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
|
||||
Nd4jLong *outputShape) {
|
||||
__shared__
|
||||
T *val;
|
||||
__shared__
|
||||
Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__
|
||||
T *x;
|
||||
__shared__
|
||||
T *z;
|
||||
__shared__
|
||||
I *y; //int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
segment = blockIdx.x;
|
||||
x = reinterpret_cast<T *>(input);
|
||||
z = reinterpret_cast<T *>(output);
|
||||
y = reinterpret_cast<I *>(indices);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
//start = starts[segment];
|
||||
//finish = start + lengths[segment];
|
||||
if (lengths[segment] > 0)
|
||||
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
|
||||
else
|
||||
z[zIndex] = -DataTypeUtils::max<T>();
|
||||
}
|
||||
__syncthreads();
|
||||
if (lengths[segment] > 0)
|
||||
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
if (y[yIndex] == segment) {
|
||||
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMaxTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads,
|
||||
Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf,
|
||||
Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets, T filler = 0) {
|
||||
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||
__shared__ T* z;
|
||||
__shared__ int start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||
len = shape::length(inputTads);
|
||||
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
total = shape::sizeAt(inputShape, 0);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto idx = blockIdx.x;
|
||||
if (blockIdx.x <= total) {
|
||||
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||
if (blockIdx.x == start) {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
z[zIndex] = x[xIndex];
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static void segmentMaxFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||
//int numClasses = output->sizeAt(0);
|
||||
// if input is a vector: (as if in doc sample)
|
||||
//Nd4jLong idx = indices->e<Nd4jLong>(0);
|
||||
auto stream = context->getCudaStream();
|
||||
indices->syncToHost();
|
||||
Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(256, 512, 256);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
|
||||
|
||||
if (input->isVector()) {
|
||||
|
||||
segmentMaxLinearKernel<T,I><<<numOfClasses, input->lengthOf(), numOfClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
segmentMaxTadKernel<T,I><<<packX.numberOfTads(), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
}
|
||||
BUILD_DOUBLE_TEMPLATE(template void segmentMaxFunctor_, (LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static void unsortedSegmentMaxFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.getSpecialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.getSpecialBuffer());
|
||||
|
||||
if (input->isVector()) {
|
||||
unsortedSegmentMaxLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
dims.x = input->sizeAt(0);
|
||||
output->assign(-DataTypeUtils::max<T>());
|
||||
segmentMaxTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// segment max
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMaxBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||
void* outputBuf, Nd4jLong* outputShape) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradIn;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, gradLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradLen = shape::length(epsShape);
|
||||
}
|
||||
|
||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = gridDim.x * blockDim.x;
|
||||
|
||||
for (auto e = start; e < xLen; e += step) {
|
||||
|
||||
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
auto classIndex = y[yOffset];
|
||||
auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen);
|
||||
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||
|
||||
if (nd4j::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) {
|
||||
z[zOffset] = gradOut[gradOffsetO];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMaxBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||
void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||
Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets,
|
||||
Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad,
|
||||
Nd4jLong* outOffsets) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradIn;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
yLen = shape::length(indicesShape);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||
gradLen = shape::length(epsShape);
|
||||
currentLen = shape::length(outTad);
|
||||
}
|
||||
|
||||
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||
auto segment = y[yIndex];
|
||||
T* current = x + inputOffsets[i];
|
||||
T* currentOut = z + outOffsets[i];
|
||||
T* in = gradIn + gradInOffsets[segment];
|
||||
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||
|
||||
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||
if (nd4j::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6))
|
||||
currentOut[e] = outGrad[e];
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
int segmentMaxFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||
//int numOfClasses = gradOut->sizeAt(0);
|
||||
// if input is a vector: (as if in doc sample)
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||
segmentMaxFunctor_<T, I>(context, input, indices, &tempRes);
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||
if (input->isVector()) {
|
||||
Nd4jLong loop_size = input->lengthOf();
|
||||
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||
segmentMaxBPLinearKernel<T,I><<<1 + gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||
|
||||
segmentMaxBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||
outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||
return Status::OK();
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input,
|
||||
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
BUILD_DOUBLE_TEMPLATE(template int segmentMaxFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static int unsortedSegmentMaxFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||
//int numOfClasses = gradOut->sizeAt(0);
|
||||
// if input is a vector: (as if in doc sample)
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||
unsortedSegmentMaxFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||
if (input->isVector()) {
|
||||
Nd4jLong loop_size = input->lengthOf();
|
||||
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||
segmentMaxBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||
|
||||
segmentMaxBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||
outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||
return Status::OK();
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,414 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/segment.h>
|
||||
#include <ops/declarable/helpers/segment_common.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <helpers/TAD.h>
|
||||
#include <exceptions/cuda_exception.h>
|
||||
#include <PointersManager.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Segment ops linear kernels
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMeanLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__ T* x;
|
||||
__shared__ T* z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x / threadsPerSegment;
|
||||
x = reinterpret_cast<T*>(input);
|
||||
z = reinterpret_cast<T*>(output);
|
||||
// extern __shared__ unsigned char shmem[];
|
||||
// val = reinterpret_cast<T*>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
//[zIndex] =
|
||||
if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
//val[segment] = ;
|
||||
z[zIndex] = T(x[shape::getIndexOffset(start, inputShape, xLen)] / lengths[segment]);
|
||||
// val[segment] = z[zIndex];
|
||||
}
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
if (lengths[segment])
|
||||
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex] / lengths[segment]));
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void unsortedSegmentMeanLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__ T* x;
|
||||
__shared__ T* z;
|
||||
__shared__ I* y; //int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x;// / threadsPerSegment;
|
||||
x = reinterpret_cast<T*>(input);
|
||||
z = reinterpret_cast<T*>(output);
|
||||
y = reinterpret_cast<I*>(indices);
|
||||
// extern __shared__ unsigned char shmem[];
|
||||
// val = reinterpret_cast<T*>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
// if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
//start = starts[segment];
|
||||
//finish = start + lengths[segment];
|
||||
if (lengths[segment] > 0)
|
||||
z[zIndex] = T(x[shape::getIndexOffset(starts[segment], inputShape, xLen)] / T(lengths[segment]));
|
||||
else
|
||||
z[zIndex] = 0; //DataTypeUtils::max<T>();
|
||||
// val[segment] = z[zIndex];
|
||||
// }
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
if (lengths[segment] > 0)
|
||||
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
if (y[yIndex] == segment && e != starts[segment]) {
|
||||
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/T(lengths[segment])));
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// SegmentMean kernel
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMeanTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||
__shared__ T* z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||
len = shape::length(inputTads);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
total = shape::sizeAt(inputShape, 0);
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto idx = blockIdx.x;
|
||||
if (blockIdx.x <= total) {
|
||||
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||
if (blockIdx.x == start) {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
z[zIndex] = T(x[xIndex]/lengths[segment]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
if (lengths[segment])
|
||||
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/lengths[segment]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// segment mean
|
||||
template <typename T, typename I>
|
||||
static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
|
||||
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||
|
||||
if (input->isVector()) {
|
||||
segmentMeanLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
segmentMeanTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
}
|
||||
BUILD_DOUBLE_TEMPLATE(template void segmentMeanFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
|
||||
if (input->isVector()) {
|
||||
unsortedSegmentMeanLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
output->assign(0);
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
dims.x = input->sizeAt(0);
|
||||
segmentMeanTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output),
|
||||
FLOAT_TYPES, INTEGER_TYPES);
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMeanFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMeanBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||
int* lengths, void* outputBuf, Nd4jLong* outputShape) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradIn;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, gradLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradLen = shape::length(epsShape);
|
||||
}
|
||||
|
||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = gridDim.x * blockDim.x;
|
||||
|
||||
for (auto e = start; e < xLen; e += step) {
|
||||
|
||||
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
auto classIndex = y[yOffset];
|
||||
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||
|
||||
z[zOffset] = T(gradOut[gradOffsetO] / float(lengths[classIndex]));
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMeanBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
|
||||
void* indicesBuf, Nd4jLong* indicesShape, int* lengths, void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||
Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
yLen = shape::length(indicesShape);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradLen = shape::length(epsShape);
|
||||
currentLen = shape::length(outTad);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||
// auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||
auto segment = y[i]; //yIndex];
|
||||
T* currentOut = z + outOffsets[i];
|
||||
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||
|
||||
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||
auto zIndex = shape::getIndexOffset(e, outTad, currentLen);
|
||||
auto gradIndex = shape::getIndexOffset(e, gradOutTad, gradLen);
|
||||
if (lengths[segment] > 0)
|
||||
currentOut[zIndex] = T(outGrad[gradIndex] / float(lengths[segment]));
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// backprop for mean
|
||||
template <typename T, typename I>
|
||||
int segmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
|
||||
if (input->isVector()) {
|
||||
Nd4jLong loop_size = input->lengthOf();
|
||||
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||
segmentMeanBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
|
||||
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
// auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||
|
||||
segmentMeanBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths,
|
||||
output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||
outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||
return Status::OK();
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// segment mean bp main
|
||||
int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input,
|
||||
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
BUILD_DOUBLE_TEMPLATE(template int segmentMeanFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static int unsortedSegmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
|
||||
if (input->isVector()) {
|
||||
Nd4jLong loop_size = input->lengthOf();
|
||||
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||
segmentMeanBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
|
||||
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
// auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||
|
||||
segmentMeanBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths,
|
||||
output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||
outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||
return Status::OK();
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMeanFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,423 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/segment.h>
|
||||
#include <ops/declarable/helpers/segment_common.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <helpers/TAD.h>
|
||||
#include <exceptions/cuda_exception.h>
|
||||
#include <PointersManager.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Segment ops linear kernels
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template<typename T, typename I>
|
||||
static __global__ void
|
||||
segmentMinLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
|
||||
void *output, Nd4jLong *outputShape) {
|
||||
__shared__
|
||||
T *val;
|
||||
__shared__
|
||||
Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__
|
||||
T *x;
|
||||
__shared__
|
||||
T *z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x / threadsPerSegment;
|
||||
x = reinterpret_cast<T *>(input);
|
||||
z = reinterpret_cast<T *>(output);
|
||||
extern __shared__ unsigned char shmem[];
|
||||
val = reinterpret_cast<T *>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
|
||||
val[segment] = z[zIndex];
|
||||
}
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template<typename T, typename I>
|
||||
static __global__ void
|
||||
unsortedSegmentMinLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
|
||||
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
|
||||
Nd4jLong *outputShape) {
|
||||
__shared__
|
||||
T *val;
|
||||
__shared__
|
||||
Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__
|
||||
T *x;
|
||||
__shared__
|
||||
T *z;
|
||||
__shared__
|
||||
I *y; //int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
segment = blockIdx.x;
|
||||
x = reinterpret_cast<T *>(input);
|
||||
z = reinterpret_cast<T *>(output);
|
||||
y = reinterpret_cast<I *>(indices);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
if (lengths[segment] > 0)
|
||||
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
|
||||
else
|
||||
z[zIndex] = DataTypeUtils::max<T>();
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
if (lengths[segment] > 0)
|
||||
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
if (y[yIndex] == segment) {
|
||||
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// SegmentMin kernel
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMinTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||
__shared__ T* z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||
len = shape::length(inputTads);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
total = shape::sizeAt(inputShape, 0);
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto idx = blockIdx.x;
|
||||
if (blockIdx.x <= total) {
|
||||
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||
if (blockIdx.x == start) {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
z[zIndex] = x[xIndex];
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// segmen min
|
||||
template <typename T, typename I>
|
||||
static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
|
||||
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
if (input->isVector()) {
|
||||
segmentMinLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
segmentMinTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
|
||||
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point for sorted segment-min: picks the segmentMinFunctor_<T, I>
// specialisation from the runtime data types of input (T) and indices (I).
void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(),
                          segmentMinFunctor_,
                          (context, input, indices, output),
                          NUMERIC_TYPES, INTEGER_TYPES);
}
|
||||
// Template instantiation request for segmentMinFunctor_ over NUMERIC_TYPES (data) and
// INTEGER_TYPES (indices) - presumably one explicit instantiation per type pair.
BUILD_DOUBLE_TEMPLATE(template void segmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static void unsortedSegmentMinFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||
if (input->isVector()) {
|
||||
unsortedSegmentMinLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
output->assign(DataTypeUtils::max<T>());
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
dims.x = input->sizeAt(0);
|
||||
segmentMinTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices});
|
||||
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point for unsorted segment-min: picks the unsortedSegmentMinFunctor_<T, I>
// specialisation from the runtime data types of input (T) and indices (I).
void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(),
                          unsortedSegmentMinFunctor_,
                          (context, input, indices, numOfClasses, output),
                          NUMERIC_TYPES, INTEGER_TYPES);
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
// Template instantiation request for unsortedSegmentMinFunctor_ over NUMERIC_TYPES (data)
// and INTEGER_TYPES (indices) - presumably one explicit instantiation per type pair.
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMinBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||
void* outputBuf, Nd4jLong* outputShape) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradIn;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, gradLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradLen = shape::length(epsShape);
|
||||
}
|
||||
|
||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = gridDim.x * blockDim.x;
|
||||
|
||||
for (auto e = start; e < xLen; e += step) {
|
||||
|
||||
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
auto classIndex = y[yOffset];
|
||||
auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen);
|
||||
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||
|
||||
if (nd4j::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) {
|
||||
z[zOffset] = gradOut[gradOffsetO];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentMinBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||
void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||
Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets,
|
||||
Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad,
|
||||
Nd4jLong* outOffsets) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradIn;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
yLen = shape::length(indicesShape);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||
gradLen = shape::length(epsShape);
|
||||
currentLen = shape::length(outTad);
|
||||
}
|
||||
|
||||
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||
auto segment = y[yIndex];
|
||||
T* current = x + inputOffsets[i];
|
||||
T* currentOut = z + outOffsets[i];
|
||||
T* in = gradIn + gradInOffsets[segment];
|
||||
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||
|
||||
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||
if (nd4j::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6))
|
||||
currentOut[e] = outGrad[e];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
int segmentMinFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||
//int numOfClasses = gradOut->sizeAt(0);
|
||||
// if input is a vector: (as if in doc sample)
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||
segmentMinFunctor_<T, I>(context, input, indices, &tempRes);
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||
if (input->isVector()) {
|
||||
Nd4jLong loop_size = input->lengthOf();
|
||||
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||
|
||||
segmentMinBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||
|
||||
segmentMinBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||
outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||
return Status::OK();
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// segment min (backprop)
|
||||
// Public entry point for the sorted segment-min backward pass: picks the
// segmentMinFunctorBP_<T, I> specialisation from the runtime data types of
// output (T) and indices (I) and returns its status.
int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(),
                          return segmentMinFunctorBP_,
                          (context, input, indices, gradOut, output),
                          NUMERIC_TYPES, INTEGER_TYPES);
}
|
||||
// Template instantiation request for segmentMinFunctorBP_ over NUMERIC_TYPES (data) and
// INTEGER_TYPES (indices) - presumably one explicit instantiation per type pair.
BUILD_DOUBLE_TEMPLATE(template int segmentMinFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static int unsortedSegmentMinFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||
//int numOfClasses = gradOut->sizeAt(0);
|
||||
// if input is a vector: (as if in doc sample)
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||
unsortedSegmentMinFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||
if (input->isVector()) {
|
||||
Nd4jLong loop_size = input->lengthOf();
|
||||
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||
segmentMinBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||
|
||||
segmentMinBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||
outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||
return Status::OK();
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point for the unsorted segment-min backward pass: picks the
// unsortedSegmentMinFunctorBP_<T, I> specialisation from the runtime data types of
// output (T) and indices (I) and returns its status.
int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(),
                          return unsortedSegmentMinFunctorBP_,
                          (context, input, indices, gradOut, numOfClasses, output),
                          NUMERIC_TYPES, INTEGER_TYPES);
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Template instantiation request for unsortedSegmentMinFunctorBP_ over NUMERIC_TYPES (data)
// and INTEGER_TYPES (indices) - presumably one explicit instantiation per type pair.
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,419 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/segment.h>
|
||||
#include <ops/declarable/helpers/segment_common.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <helpers/TAD.h>
|
||||
#include <exceptions/cuda_exception.h>
|
||||
#include <PointersManager.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Segment Prod ops linear kernels
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentProdLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__ T* x;
|
||||
__shared__ T* z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x / threadsPerSegment;
|
||||
x = reinterpret_cast<T*>(input);
|
||||
z = reinterpret_cast<T*>(output);
|
||||
extern __shared__ unsigned char shmem[];
|
||||
val = reinterpret_cast<T*>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
//val[segment] = ;
|
||||
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
|
||||
val[segment] = z[zIndex];
|
||||
}
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
// auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
// auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
nd4j::math::atomics::nd4j_atomicMul(&val[segment], x[xIndex]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
z[zIndex] = val[segment];
|
||||
}
|
||||
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void unsortedSegmentProdLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__ T* x;
|
||||
__shared__ T* z;
|
||||
__shared__ I* y; //int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x;// / threadsPerSegment;
|
||||
x = reinterpret_cast<T*>(input);
|
||||
z = reinterpret_cast<T*>(output);
|
||||
y = reinterpret_cast<I*>(indices);
|
||||
// extern __shared__ unsigned char shmem[];
|
||||
// val = reinterpret_cast<T*>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
// if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
//start = starts[segment];
|
||||
//finish = start + lengths[segment];
|
||||
if (lengths[segment] > 0)
|
||||
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
|
||||
else
|
||||
z[zIndex] = 0; //DataTypeUtils::max<T>();
|
||||
// val[segment] = z[zIndex];
|
||||
// }
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
if (lengths[segment] > 0)
|
||||
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
if (y[yIndex] == segment && e != starts[segment]) {
|
||||
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// SegmentProd kernel
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||
__shared__ T* z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||
len = shape::length(inputTads);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
total = shape::sizeAt(inputShape, 0);
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto idx = blockIdx.x;
|
||||
if (blockIdx.x <= total) {
|
||||
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||
if (blockIdx.x == start) {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
z[zIndex] = x[xIndex];
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static void segmentProdFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
|
||||
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
|
||||
if (input->isVector()) {
|
||||
segmentProdLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
segmentProdTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point for sorted segment-product: picks the segmentProdFunctor_<T, I>
// specialisation from the runtime data types of output (T) and indices (I).
void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(),
                          segmentProdFunctor_,
                          (context, input, indices, output),
                          NUMERIC_TYPES, INTEGER_TYPES);
}
|
||||
// Template instantiation request for segmentProdFunctor_ over FLOAT_TYPES (data) and
// INTEGER_TYPES (indices).
// NOTE(review): the dispatcher segmentProdFunctor selects over NUMERIC_TYPES while this
// line lists FLOAT_TYPES only - confirm the mismatch is intended.
BUILD_DOUBLE_TEMPLATE(template void segmentProdFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static void unsortedSegmentProdFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
|
||||
if (input->isVector()) {
|
||||
unsortedSegmentProdLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
output->assign(1);
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
dims.x = input->sizeAt(0);
|
||||
segmentProdTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point: forwards to unsortedSegmentProdFunctor_<T, I>, with the template arguments chosen
// from the runtime data types of `input` and `indices` by BUILD_DOUBLE_SELECTOR (macro defined elsewhere).
void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(),
                          unsortedSegmentProdFunctor_,
                          (context, input, indices, numOfClasses, output),
                          FLOAT_TYPES, INTEGER_TYPES);
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Explicit instantiation request for unsortedSegmentProdFunctor_<T, I>; the macro is defined outside this file
// and presumably expands to one instantiation per (FLOAT_TYPES x INTEGER_TYPES) pair.
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentProdBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||
void* outputBuf, Nd4jLong* outputShape) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradIn;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, gradLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradLen = shape::length(epsShape);
|
||||
}
|
||||
|
||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = gridDim.x * blockDim.x;
|
||||
|
||||
for (auto e = start; e < xLen; e += step) {
|
||||
|
||||
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
auto classIndex = y[yOffset];
|
||||
auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen);
|
||||
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||
|
||||
z[zOffset] = gradOut[gradOffsetO] * gradIn[gradOffsetI] / x[xOffset];
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentProdBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||
void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||
Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets,
|
||||
Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad,
|
||||
Nd4jLong* outOffsets) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradIn;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
yLen = shape::length(indicesShape);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||
gradLen = shape::length(epsShape);
|
||||
currentLen = shape::length(outTad);
|
||||
}
|
||||
|
||||
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||
auto segment = y[yIndex];
|
||||
T* current = x + inputOffsets[i];
|
||||
T* currentOut = z + outOffsets[i];
|
||||
T* in = gradIn + gradInOffsets[segment];
|
||||
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||
|
||||
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||
currentOut[e] = outGrad[e] * in[e] / current[e];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
int segmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||
segmentProdFunctor_<T, I>(context, input, indices, &tempRes);
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||
if (input->isVector()) {
|
||||
Nd4jLong loopSize = input->lengthOf();
|
||||
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||
segmentProdBPLinearKernel<T,I><<<gradOut->lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||
|
||||
segmentProdBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||
outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
// Public entry point: returns the result of segmentProdFunctorBP_<T, I>, with the template arguments chosen
// from the runtime data types of `output` and `indices` by BUILD_DOUBLE_SELECTOR (macro defined elsewhere).
int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(),
                          return segmentProdFunctorBP_,
                          (context, input, indices, gradOut, output),
                          FLOAT_TYPES, INTEGER_TYPES);
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
// Explicit instantiation request for segmentProdFunctorBP_<T, I>; the macro is defined outside this file
// and presumably expands to one instantiation per (FLOAT_TYPES x INTEGER_TYPES) pair.
BUILD_DOUBLE_TEMPLATE(template int segmentProdFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static int unsortedSegmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||
unsortedSegmentProdFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||
if (input->isVector()) {
|
||||
Nd4jLong loopSize = input->lengthOf();
|
||||
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||
segmentProdBPLinearKernel<T,I><<<gradOut->lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||
|
||||
segmentProdBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||
outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point: returns the result of unsortedSegmentProdFunctorBP_<T, I>, with the template arguments
// chosen from the runtime data types of `output` and `indices` by BUILD_DOUBLE_SELECTOR (macro defined elsewhere).
int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(),
                          return unsortedSegmentProdFunctorBP_,
                          (context, input, indices, gradOut, numOfClasses, output),
                          FLOAT_TYPES, INTEGER_TYPES);
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Explicit instantiation request for unsortedSegmentProdFunctorBP_<T, I>; the macro is defined outside this file
// and presumably expands to one instantiation per (FLOAT_TYPES x INTEGER_TYPES) pair.
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentProdFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,280 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/segment.h>
|
||||
#include <ops/declarable/helpers/segment_common.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <helpers/TAD.h>
|
||||
#include <exceptions/cuda_exception.h>
|
||||
#include <PointersManager.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void unsortedSegmentSqrtNLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__ T* x;
|
||||
__shared__ T* z;
|
||||
__shared__ I* y; //int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x;// / threadsPerSegment;
|
||||
x = reinterpret_cast<T*>(input);
|
||||
z = reinterpret_cast<T*>(output);
|
||||
y = reinterpret_cast<I*>(indices);
|
||||
// extern __shared__ unsigned char shmem[];
|
||||
// val = reinterpret_cast<T*>(shmem);
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
// if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
//start = starts[segment];
|
||||
//finish = start + lengths[segment];
|
||||
if (lengths[segment] > 0)
|
||||
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]);
|
||||
else
|
||||
z[zIndex] = 0; //DataTypeUtils::max<T>();
|
||||
// val[segment] = z[zIndex];
|
||||
// }
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
if (lengths[segment] > 0)
|
||||
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
if (y[yIndex] == segment && e != starts[segment]) {
|
||||
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// SegmentSqrtN kernel
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentSqrtNTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||
__shared__ T* val;
|
||||
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||
__shared__ T* z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||
len = shape::length(inputTads);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
total = shape::sizeAt(inputShape, 0);
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto idx = blockIdx.x;
|
||||
if (blockIdx.x <= total) {
|
||||
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||
if (blockIdx.x == start) {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
z[zIndex] = x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static void unsortedSegmentSqrtNFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
|
||||
if (input->isVector()) {
|
||||
unsortedSegmentSqrtNLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
output->assign(0);
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
dims.x = input->sizeAt(0);
|
||||
segmentSqrtNTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point: forwards to unsortedSegmentSqrtNFunctor_<T, I>, with the template arguments chosen
// from the runtime data types of `input` and `indices` by BUILD_DOUBLE_SELECTOR (macro defined elsewhere).
void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(),
                          unsortedSegmentSqrtNFunctor_,
                          (context, input, indices, numOfClasses, output),
                          FLOAT_TYPES, INTEGER_TYPES);
}
|
||||
// Explicit instantiation request for unsortedSegmentSqrtNFunctor_<T, I>; the macro is defined outside this file
// and presumably expands to one instantiation per (FLOAT_TYPES x INTEGER_TYPES) pair.
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSqrtNFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentSqrtNBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||
int* lengths, void* outputBuf, Nd4jLong* outputShape) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradIn;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, gradLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradLen = shape::length(epsShape);
|
||||
}
|
||||
|
||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto step = gridDim.x * blockDim.x;
|
||||
|
||||
for (auto e = start; e < xLen; e += step) {
|
||||
|
||||
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||
auto classIndex = y[yOffset];
|
||||
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||
|
||||
z[zOffset] = T(gradOut[gradOffsetO] / math::nd4j_sqrt<int, float>(lengths[classIndex]));
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static __global__ void segmentSqrtNBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
|
||||
void* indicesBuf, Nd4jLong* indicesShape, int* lengths, void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||
Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) {
|
||||
__shared__ T* x;
|
||||
__shared__ T* gradOut;
|
||||
__shared__ I* y;
|
||||
__shared__ T* z;
|
||||
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(inputShape);
|
||||
x = reinterpret_cast<T*>(inputBuf);
|
||||
y = reinterpret_cast<I*>(indicesBuf);
|
||||
z = reinterpret_cast<T*>(outputBuf);
|
||||
yLen = shape::length(indicesShape);
|
||||
gradOut = reinterpret_cast<T*>(eps);
|
||||
gradLen = shape::length(epsShape);
|
||||
currentLen = shape::length(outTad);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||
// auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||
auto segment = y[i]; //yIndex];
|
||||
T* currentOut = z + outOffsets[i];
|
||||
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||
|
||||
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||
auto zIndex = shape::getIndexOffset(e, outTad, currentLen);
|
||||
auto gradIndex = shape::getIndexOffset(e, gradOutTad, gradLen);
|
||||
if (lengths[segment] > 0)
|
||||
currentOut[zIndex] = T(outGrad[gradIndex] / math::nd4j_sqrt<int, float>(lengths[segment]));
|
||||
}
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
template <typename T, typename I>
|
||||
static int unsortedSegmentSqrtNFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||
auto stream = context->getCudaStream();
|
||||
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
||||
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||
|
||||
classesRangesBegs.assign(indices->lengthOf());
|
||||
classesRangesLens.assign(0);
|
||||
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||
|
||||
if (input->isVector()) {
|
||||
Nd4jLong loop_size = input->lengthOf();
|
||||
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||
segmentSqrtNBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
|
||||
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||
indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo());
|
||||
}
|
||||
else {
|
||||
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||
// auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||
|
||||
segmentSqrtNBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||
gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths,
|
||||
output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||
outputTads, outputTadOffsets);
|
||||
}
|
||||
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point: returns the result of unsortedSegmentSqrtNFunctorBP_<T, I>, with the template arguments
// chosen from the runtime data types of `output` and `indices` by BUILD_DOUBLE_SELECTOR (macro defined elsewhere).
int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(),
                          return unsortedSegmentSqrtNFunctorBP_,
                          (context, input, indices, gradOut, numOfClasses, output),
                          FLOAT_TYPES, INTEGER_TYPES);
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Explicit instantiation request for unsortedSegmentSqrtNFunctorBP_<T, I>; the macro is defined outside this file
// and presumably expands to one instantiation per (FLOAT_TYPES x INTEGER_TYPES) pair.
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSqrtNFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,393 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author GS <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/segment.h>
|
||||
#include <ops/declarable/helpers/segment_common.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <helpers/TAD.h>
|
||||
#include <exceptions/cuda_exception.h>
|
||||
#include <PointersManager.h>
|
||||
#include <ConstantTadHelper.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Segment ops linear kernels
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
template<typename T, typename I>
|
||||
static __global__ void
|
||||
segmentSumLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
|
||||
void *output, Nd4jLong *outputShape) {
|
||||
__shared__
|
||||
T *val;
|
||||
__shared__
|
||||
Nd4jLong xLen, zLen, segment, zIndex;
|
||||
__shared__
|
||||
T *x;
|
||||
__shared__
|
||||
T *z;
|
||||
__shared__ int threadsPerSegment, start, finish;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||
segment = blockIdx.x / threadsPerSegment;
|
||||
x = reinterpret_cast<T *>(input);
|
||||
z = reinterpret_cast<T *>(output);
|
||||
|
||||
xLen = shape::length(inputShape);
|
||||
zLen = shape::length(outputShape);
|
||||
|
||||
|
||||
if (segment < numOfClasses) {
|
||||
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||
start = starts[segment];
|
||||
finish = start + lengths[segment];
|
||||
//val[segment] = ;
|
||||
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
|
||||
}
|
||||
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
|
||||
}
|
||||
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
// Unsorted segment-sum over a linear input; block b accumulates class b.
//
//   input / inputShape      values (reinterpreted as T) and their shape info
//   indices / indicesShape  class id of every value (reinterpreted as I)
//   starts / lengths        per-class seed element index and element count
//   output / outputShape    one accumulated value per class
//
// Thread 0 seeds the output element with x[starts[class]] (or 0 for a class
// without elements); afterwards every thread scans the input with a stride of
// blockDim.x and atomically adds the values that belong to the class, skipping
// the seed element.
template<typename T, typename I>
static __global__ void
unsortedSegmentSumLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
                               int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
                               Nd4jLong *outputShape) {
    __shared__ Nd4jLong xLen, zLen, segment, zIndex;
    __shared__ T *x;
    __shared__ T *z;
    __shared__ I *y;

    if (threadIdx.x == 0) {
        segment = blockIdx.x;
        x = reinterpret_cast<T *>(input);
        z = reinterpret_cast<T *>(output);
        y = reinterpret_cast<I *>(indices);
        xLen = shape::length(inputShape);
        zLen = shape::length(outputShape);
        zIndex = shape::getIndexOffset(segment, outputShape, zLen);

        if (lengths[segment] > 0) {
            z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
        }
        else {
            z[zIndex] = 0;
        }
    }
    __syncthreads();

    // nothing to accumulate for a class without elements
    if (!(lengths[segment] > 0))
        return;

    for (auto pos = threadIdx.x; pos < xLen; pos += blockDim.x) {
        auto valueOffset = shape::getIndexOffset(pos, inputShape, xLen);
        auto classOffset = shape::getIndexOffset(pos, indicesShape, xLen);

        const bool belongsToClass = y[classOffset] == segment;
        // the seed element has already been stored by thread 0
        if (belongsToClass && pos != starts[segment])
            nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[valueOffset]);
    }
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// SegmentSum kernel
|
||||
// Segment-sum over sub-arrays (TADs) taken along dimension 0.
//
// Block b processes input TAD b and adds it into the output TAD selected by
// indices[b]. The block whose index equals starts[segment] stores its TAD
// directly, every other block of that segment adds atomically.
//
//   inputTads / inputTadOffsets    shape info and buffer offsets of input TADs
//   outputTads / outputTadOffsets  shape info and buffer offsets of output TADs
//   indices                        segment id per input TAD
//   starts / lengths               per-segment first TAD index and TAD count
//
// Blocks beyond the size of dimension 0 exit before touching `indices` or the
// offset tables.
// NOTE(review): the plain store done by the "start" block is not ordered
// against atomic adds issued by other blocks of the same segment - confirm the
// callers' assumptions about block scheduling / output initialisation.
template <typename T, typename I>
static __global__ void segmentSumTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
    __shared__ Nd4jLong len, segment, total;
    __shared__ T* z;
    __shared__ int start;

    if (threadIdx.x == 0) {
        total = shape::sizeAt(inputShape, 0);
        len = shape::length(inputTads);

        // only valid TAD indices may be used to look up the segment
        if (blockIdx.x < total) {
            segment = indices[blockIdx.x];
            z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
            start = starts[segment];
        }
    }
    __syncthreads();

    // uniform per block, so no thread is left waiting on a barrier
    if (blockIdx.x >= total)
        return;

    auto x = reinterpret_cast<T*>(inputBuf) + inputTadOffsets[blockIdx.x];

    if (blockIdx.x == start) {
        // first TAD of the segment: plain copy
        for (auto e = threadIdx.x; e < len; e += blockDim.x) {
            auto xIndex = shape::getIndexOffset(e, inputTads, len);
            auto zIndex = shape::getIndexOffset(e, outputTads, len);
            z[zIndex] = x[xIndex];
        }
    }
    else if (lengths[segment]) {
        // remaining TADs of the segment: accumulate
        for (auto e = threadIdx.x; e < len; e += blockDim.x) {
            auto xIndex = shape::getIndexOffset(e, inputTads, len);
            auto zIndex = shape::getIndexOffset(e, outputTads, len);
            nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
        }
    }
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
// Host-side driver for the sorted segment-sum.
//
// The number of classes is taken from the last element of `indices` plus one.
// Per-class start indices and lengths are produced by fillUpSegments() and
// handed to the kernels: the linear kernel for vector inputs, the TAD kernel
// (sub-arrays along dimension 0) otherwise.
//
// The block size of the linear launch is clamped to [1, 512]; the kernel walks
// its segment with a stride of blockDim.x, so a smaller block only changes how
// the work is split, while an unclamped input length would exceed the device
// limit on threads per block for long inputs.
// NOTE(review): the dynamic shared memory size still grows with numClasses
// although segmentSumLinearKernel declares no extern shared storage - confirm
// whether a constant would do.
template <typename T, typename I>
static void segmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
    auto stream = context->getCudaStream();
    Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
    NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
    NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});

    // classes that never occur keep start == indices length and length == 0
    classesRangesBegs.assign(indices->lengthOf());
    classesRangesLens.assign(0);

    fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
    int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
    int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());

    // same device-buffer bookkeeping the backprop functors perform around their launches
    NDArray::prepareSpecialUse({output}, {input, indices});

    if (input->isVector()) {
        const Nd4jLong inputLen = input->lengthOf();
        const int threads = inputLen < 1 ? 1 : (inputLen > 512 ? 512 : static_cast<int>(inputLen));
        segmentSumLinearKernel<T,I><<<numClasses, threads, numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
    }
    else {
        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
        Nd4jLong* inputTads = packX.specialShapeInfo();
        Nd4jLong* inputTadOffsets = packX.specialOffsets();
        Nd4jLong* outputTads = packZ.specialShapeInfo();
        Nd4jLong* outputTadOffsets = packZ.specialOffsets();
        segmentSumTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
    }

    NDArray::registerSpecialUse({output}, {input, indices});
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point of the sorted segment-sum: selects the typed implementation
// from the data types of `input` (numeric) and `indices` (integer).
void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(),
                          segmentSumFunctor_, (context, input, indices, output),
                          NUMERIC_TYPES, INTEGER_TYPES);
}

// explicit instantiations for every numeric/integer type pair
BUILD_DOUBLE_TEMPLATE(template void segmentSumFunctor_,
                      (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output),
                      NUMERIC_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Host-side driver for the unsorted segment-sum.
//
// Per-class start indices and lengths are produced by fillUpSegments(); vector
// inputs go through unsortedSegmentSumLinearKernel (one block per class), all
// other inputs are zeroed in `output` first and then accumulated TAD-wise along
// dimension 0 by segmentSumTadKernel (one block per input TAD).
//
// The block size is clamped to [1, 512]: both kernels iterate with a stride of
// blockDim.x, so the clamp only changes how work is split, while an unclamped
// indices length would exceed the device limit on threads per block.
template <typename T, typename I>
static void unsortedSegmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
    auto stream = context->getCudaStream();
    NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
    NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});

    // classes that never occur keep start == indices length and length == 0
    classesRangesBegs.assign(indices->lengthOf());
    classesRangesLens.assign(0);

    const Nd4jLong indicesLen = indices->lengthOf();
    const int threads = indicesLen < 1 ? 1 : (indicesLen > 512 ? 512 : static_cast<int>(indicesLen));
    dim3 dims(numOfClasses, threads, (numOfClasses + 1) * 64);

    fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
    int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
    int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());

    // same device-buffer bookkeeping the backprop functors perform around their launches
    NDArray::prepareSpecialUse({output}, {input, indices});

    if (input->isVector()) {
        unsortedSegmentSumLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
    }
    else {
        output->assign(0);
        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
        Nd4jLong* inputTads = packX.specialShapeInfo();
        Nd4jLong* inputTadOffsets = packX.specialOffsets();
        Nd4jLong* outputTads = packZ.specialShapeInfo();
        Nd4jLong* outputTadOffsets = packZ.specialOffsets();
        // one block per input TAD
        dims.x = input->sizeAt(0);
        segmentSumTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
    }

    NDArray::registerSpecialUse({output}, {input, indices});
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point of the unsorted segment-sum: selects the typed implementation
// from the data types of `input` (numeric) and `indices` (integer).
void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(),
                          unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output),
                          NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
// explicit instantiations for every numeric/integer type pair
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSumFunctor_,
                      (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output),
                      NUMERIC_TYPES, INTEGER_TYPES);
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Backpropagate ops
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Sorted sum backpropagate
|
||||
// Backprop of segment-sum for a linear input: every output element receives
// the incoming gradient of the class its index points to.
//
//   inputShape            shape info of the forward input (defines the element count)
//   eps / epsShape        incoming gradient, one value per class
//   indicesBuf / ...      class id of every input element (reinterpreted as I)
//   outputBuf / ...       gradient w.r.t. the input
//
// Work is distributed over all threads of the grid
// (start = blockIdx.x * blockDim.x + threadIdx.x, step = gridDim.x * blockDim.x).
// `inputBuf` is not read: the gradient of a sum does not depend on the values.
template <typename T, typename I>
static __global__ void segmentSumBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
                                                void* indicesBuf, Nd4jLong* indicesShape, void* outputBuf, Nd4jLong* outputShape) {
    __shared__ T* gradOut;
    __shared__ I* y;
    __shared__ T* z;
    __shared__ Nd4jLong xLen, gradLen;

    if (threadIdx.x == 0) {
        xLen = shape::length(inputShape);
        y = reinterpret_cast<I*>(indicesBuf);
        z = reinterpret_cast<T*>(outputBuf);
        gradOut = reinterpret_cast<T*>(eps);
        gradLen = shape::length(epsShape);
    }
    // the shared pointers/lengths are written by thread 0 only; every other
    // thread must wait for them before entering the loop
    __syncthreads();

    auto start = blockIdx.x * blockDim.x + threadIdx.x;
    auto step = gridDim.x * blockDim.x;

    for (auto e = start; e < xLen; e += step) {
        auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
        auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
        auto classIndex = y[yOffset];
        auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);

        z[zOffset] = gradOut[gradOffsetO];
    }
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Backprop of segment-sum for sub-arrays (TADs): output TAD i receives a copy
// of the incoming-gradient TAD of the segment indices[i].
//
//   eps / gradOutOffsets   incoming gradient buffer and the offsets of its TADs
//   indicesBuf / ...       segment id per input TAD (reinterpreted as I)
//   outputBuf / outOffsets gradient w.r.t. the input and the offsets of its TADs
//   outTad                 shape info of one output TAD (defines the copy length)
//
// Blocks stride over the indices, threads stride over the elements of a TAD.
// NOTE(review): elements are addressed as currentOut[e] / outGrad[e], i.e. the
// TAD shape infos are not consulted for strides - presumably TADs are expected
// to be contiguous; confirm.
template <typename T, typename I>
static __global__ void segmentSumBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
                                             void* indicesBuf, Nd4jLong* indicesShape, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* inputTad,
                                             Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) {
    __shared__ T* x;
    __shared__ T* gradOut;
    __shared__ I* y;
    __shared__ T* z;
    __shared__ Nd4jLong xLen, yLen, gradLen, currentLen;

    if (threadIdx.x == 0) {
        xLen = shape::length(inputShape);
        x = reinterpret_cast<T*>(inputBuf);
        y = reinterpret_cast<I*>(indicesBuf);
        z = reinterpret_cast<T*>(outputBuf);
        yLen = shape::length(indicesShape);
        gradOut = reinterpret_cast<T*>(eps);
        gradLen = shape::length(epsShape);
        currentLen = shape::length(outTad);
    }
    // the shared pointers/lengths are written by thread 0 only; every other
    // thread must wait for them before entering the loops
    __syncthreads();

    for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
        auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
        auto segment = y[yIndex];
        T* currentOut = z + outOffsets[i];
        T* outGrad = gradOut + gradOutOffsets[segment];

        for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
            currentOut[e] = outGrad[e];
        }
    }
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Host-side driver for the backprop of the sorted segment-sum.
//
// Vector inputs use segmentSumBPLinearKernel, everything else copies gradient
// TADs (sub-arrays along dimension 0) with segmentSumBPTadKernel.
// Returns Status::OK().
//
// The block size is clamped to [1, 512]: both kernels iterate with strides
// derived from blockDim.x / gridDim.x, so the clamp only changes how work is
// split, while an unclamped input length would exceed the device limit on
// threads per block for long inputs.
template <typename T, typename I>
int segmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
    auto stream = context->getCudaStream();
    NDArray::prepareSpecialUse({output}, {input, indices, gradOut});

    const Nd4jLong inputLen = input->lengthOf();
    const int threads = inputLen < 1 ? 1 : (inputLen > 512 ? 512 : static_cast<int>(inputLen));

    if (input->isVector()) {
        segmentSumBPLinearKernel<T,I><<<gradOut->lengthOf(), threads, 256, *stream>>>(input->specialBuffer(),
                input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
                indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
    }
    else {
        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
        auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
        Nd4jLong* inputTads = packX.specialShapeInfo();
        Nd4jLong* inputTadOffsets = packX.specialOffsets();
        Nd4jLong* outputTads = packZ.specialShapeInfo();
        Nd4jLong* outputTadOffsets = packZ.specialOffsets();
        Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
        Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();

        segmentSumBPTadKernel<T,I><<<gradOut->lengthOf(), threads, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
                gradOut->specialBuffer(), gradOut->specialShapeInfo(),
                indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
                inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
                outputTads, outputTadOffsets);
    }
    NDArray::registerSpecialUse({output}, {input, indices, gradOut});
    return Status::OK();
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
// Public entry point of the sorted segment-sum backprop: selects the typed
// implementation from the data types of `output` (numeric) and `indices`
// (integer) and returns its status.
int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(),
                          return segmentSumFunctorBP_, (context, input, indices, gradOut, output),
                          NUMERIC_TYPES, INTEGER_TYPES);
}

// explicit instantiations for every numeric/integer type pair
BUILD_DOUBLE_TEMPLATE(template int segmentSumFunctorBP_,
                      (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output),
                      NUMERIC_TYPES, INTEGER_TYPES);
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
|
||||
// Host-side driver for the backprop of the unsorted segment-sum.
//
// Vector inputs use segmentSumBPLinearKernel, everything else copies gradient
// TADs (sub-arrays along dimension 0) with segmentSumBPTadKernel.
// `numOfClasses` is accepted for interface symmetry; the kernels take the class
// count from the gradient array. Returns Status::OK().
//
// The block size is clamped to [1, 512]: both kernels iterate with strides
// derived from blockDim.x / gridDim.x, so the clamp only changes how work is
// split, while an unclamped input length would exceed the device limit on
// threads per block for long inputs.
template <typename T, typename I>
static int unsortedSegmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
    auto stream = context->getCudaStream();
    NDArray::prepareSpecialUse({output}, {input, indices, gradOut});

    const Nd4jLong inputLen = input->lengthOf();
    const int threads = inputLen < 1 ? 1 : (inputLen > 512 ? 512 : static_cast<int>(inputLen));

    if (input->isVector()) {
        segmentSumBPLinearKernel<T,I><<<gradOut->lengthOf(), threads, 256, *stream>>>(input->specialBuffer(),
                input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
                indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
    }
    else {
        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
        auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
        Nd4jLong* inputTads = packX.specialShapeInfo();
        Nd4jLong* inputTadOffsets = packX.specialOffsets();
        Nd4jLong* outputTads = packZ.specialShapeInfo();
        Nd4jLong* outputTadOffsets = packZ.specialOffsets();
        Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
        Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();

        segmentSumBPTadKernel<T,I><<<gradOut->lengthOf(), threads, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
                gradOut->specialBuffer(), gradOut->specialShapeInfo(),
                indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
                inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
                outputTads, outputTadOffsets);
    }
    NDArray::registerSpecialUse({output}, {input, indices, gradOut});
    return Status::OK();
}
|
||||
// -------------------------------------------------------------------------------------------------------------- //
|
||||
// Public entry point of the unsorted segment-sum backprop: selects the typed
// implementation from the data types of `output` (numeric) and `indices`
// (integer) and returns its status.
int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
    BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(),
                          return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output),
                          NUMERIC_TYPES, INTEGER_TYPES);
}
// -------------------------------------------------------------------------------------------------------------- //
// explicit instantiations for every numeric/integer type pair
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSumFunctorBP_,
                      (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output),
                      NUMERIC_TYPES, INTEGER_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -24,16 +24,40 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
template <typename T>
|
||||
static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) {
|
||||
//
|
||||
// Writes B(true) to output[k * maxIndex + i] for every input element k and
// every position i < maxIndex with i < input[k].
//
//   inputBuf / inputShape    sequence lengths (reinterpreted as I)
//   outputBuf / outputShape  mask (reinterpreted as B)
//   maxIndex                 number of mask positions per input element
//
// Blocks stride over the positions i, threads stride over the input elements k.
// Positions that do not satisfy the condition are not written.
// NOTE(review): presumably the caller provides a zero/false-initialised
// output - confirm.
template <typename I, typename B>
static __global__ void sequenceMaskKernel(void* inputBuf, Nd4jLong* inputShape, void* outputBuf, Nd4jLong* outputShape, int maxIndex) {

    __shared__ I* input;
    __shared__ B* output;
    __shared__ Nd4jLong inputLen, outputLen;

    if (threadIdx.x == 0) {
        input = reinterpret_cast<I*>(inputBuf);
        output = reinterpret_cast<B*>(outputBuf);
        inputLen = shape::length(inputShape);
        outputLen = shape::length(outputShape);
    }
    // the shared pointers/lengths are written by thread 0 only; every other
    // thread must wait for them before entering the loops
    __syncthreads();

    for (auto i = blockIdx.x; i < maxIndex; i += gridDim.x)
        for (auto k = threadIdx.x; k < inputLen; k += blockDim.x)
            if (i < input[shape::getIndexOffset(k, inputShape, inputLen)])
                output[shape::getIndexOffset(k * maxIndex + i, outputShape, outputLen)] = B(true);
}
|
||||
|
||||
// Launches sequenceMaskKernel for an integer input of type I and a mask of
// type B, with one block per mask position (maxIndex blocks).
//
// The block size is clamped to [1, 512]: the kernel walks the input with a
// stride of blockDim.x, so the clamp only changes how work is split, while an
// unclamped input length would exceed the device limit on threads per block
// for long inputs.
template <typename I, typename B>
static void sequenceMask_(LaunchContext* context, NDArray* input, NDArray* output, int maxIndex) {
    const Nd4jLong inputLen = input->lengthOf();
    const int threads = inputLen < 1 ? 1 : (inputLen > 512 ? 512 : static_cast<int>(inputLen));
    dim3 launchDims(maxIndex, threads, 128);

    NDArray::prepareSpecialUse({output}, {input});
    auto stream = context->getCudaStream();
    sequenceMaskKernel<I, B><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), maxIndex);
    NDArray::registerSpecialUse({output}, {input});
}
|
||||
|
||||
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
||||
BUILD_SINGLE_SELECTOR(input->dataType(), sequenceMask_, (input, output, maxIndex), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), LIBND4J_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -456,27 +456,246 @@ void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArr
|
|||
manager.synchronize();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// Applies `numOfInd` sub-array updates to x: for every e, the sub-array of x
// selected by indexes[e] is combined element-wise with the e-th sub-array of y.
//
//   opCode  0: x += y   1: x -= y   2: x *= y   3: x /= y
//           4: x = y - x   5: x = y / x   6: x = y   (other values: no-op)
//   xShapeInfo / xOffsets  shape info of one x sub-array and per-sub-array offsets
//   yShapeInfo / yOffsets  same for y
//
// A given x sub-array is always handled by the block with index
// (indexes[e] % gridDim.x), and the updates are walked in order inside that
// block. The kernel returns as soon as the two sub-array lengths differ.
template<typename T>
__global__ static void scatterUpdateCuda(const int opCode, const int numOfInd,
                                         void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets,
                                         void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets,
                                         const int* indexes) {

    __shared__ T *xSub, *ySub;
    __shared__ Nd4jLong xSubLen, ySubLen;

    for (int upd = 0; upd < numOfInd; upd++) {

        const auto target = indexes[upd];

        // ownership test; for target < gridDim.x the modulo is the identity
        if (blockIdx.x != target % gridDim.x)
            continue;

        if (threadIdx.x == 0) {
            xSub = reinterpret_cast<T*>(vx) + xOffsets[target];
            ySub = reinterpret_cast<T*>(vy) + yOffsets[upd];
            xSubLen = shape::length(xShapeInfo);
            ySubLen = shape::length(yShapeInfo);
        }

        __syncthreads();

        if (xSubLen != ySubLen)
            return;

        for (Nd4jLong i = threadIdx.x; i < xSubLen; i += blockDim.x) {

            T& dst = xSub[shape::getIndexOffset(i, xShapeInfo, xSubLen)];
            const T& src = ySub[shape::getIndexOffset(i, yShapeInfo, ySubLen)];

            switch (opCode) {
                case 0:  dst += src;        break;
                case 1:  dst -= src;        break;
                case 2:  dst *= src;        break;
                case 3:  dst /= src;        break;
                case 4:  dst = src - dst;   break;
                case 5:  dst = src / dst;   break;
                case 6:  dst = src;         break;
                default:                    break;   // unknown opCode: element left untouched
            }
        }
        // the shared pointers are rewritten by the next owned update
        __syncthreads();
    }
}
|
||||
|
||||
// Host-side launcher of scatterUpdateCuda<T> on the given stream with a fixed
// configuration: 512 blocks, 256 threads per block, MAX_NUM_THREADS bytes of
// dynamic shared memory.
template<typename T>
__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const int* indexes) {

    const int blocksPerGrid   = 512;
    const int threadsPerBlock = 256;

    scatterUpdateCuda<T><<<blocksPerGrid, threadsPerBlock, MAX_NUM_THREADS, *stream>>>(
            opCode, numOfInd,
            vx, xShapeInfo, xOffsets,
            vy, yShapeInfo, yOffsets,
            indexes);
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// Scatter-style in-place update of `input` with `updates`.
//
// Layout of intArgs as read here:
//   [0]                      opCode
//   [1]                      numOfDims
//   [2 .. 2+numOfDims)       dimensions the sub-arrays are taken along
//   [2+numOfDims]            numOfInd
//   [3+numOfDims .. )        numOfInd indexes of the sub-arrays of `input` to update
void scatterUpdate(nd4j::LaunchContext* context, NDArray& input, NDArray& updates, const std::vector<int>* intArgs) {

    const std::vector<int>& args = *intArgs;

    const int opCode    = args[0];
    const int numOfDims = args[1];
    const int numOfInd  = args[2 + numOfDims];

    std::vector<int> tadDimensions(numOfDims);
    for (int d = 0; d < numOfDims; d++)
        tadDimensions[d] = args[d + 2];

    auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), tadDimensions);
    auto packY = ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), tadDimensions);

    // wrap the index section of intArgs as an INT32 array
    int* firstIndex = const_cast<int*>(args.data()) + numOfDims + 3;
    NDArray indices(firstIndex, 'c', {numOfInd}, nd4j::DataType::INT32, context);

    PointersManager manager(context, "scatterUpdate");

    NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices});
    BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, (context->getCudaStream(), opCode, numOfInd, input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), updates.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), reinterpret_cast<int*>(indices.getSpecialBuffer())), LIBND4J_TYPES);
    NDArray::registerSpecialUse({&input}, {&input, &updates, &indices});

    manager.synchronize();
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void randomShuffle_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) {
|
||||
static __global__ void swapShuffleKernel(T* input, Nd4jLong* shape, Nd4jLong firstDim, Nd4jLong len, nd4j::graph::RandomGenerator* rng) {
|
||||
auto tid = blockIdx.x * blockDim.x;
|
||||
auto step = blockDim.x * gridDim.x;
|
||||
|
||||
for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
|
||||
int r = rng->relativeInt(i) % i;
|
||||
if (i != r) {
|
||||
T e0 = input[shape::getIndexOffset(i, shape, len)];
|
||||
T e1 = input[shape::getIndexOffset(r, shape, len)];
|
||||
//math::nd4j_swap<T>(input(i), input(r));
|
||||
input[shape::getIndexOffset(i, shape, len)] = e1;
|
||||
input[shape::getIndexOffset(r, shape, len)] = e0;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Out-of-place shuffle step for a linear array: for every position i handled
// by a thread, a partner r = rng->relativeInt(i) % i is drawn, output[i] and
// output[r] are filled from the input through the permutation table `indices`,
// and the two table entries are exchanged.
//
//   input / inputShape    source values and shape info
//   output / outputShape  destination values and shape info
//   firstDim              number of positions to shuffle
//   len                   length passed to shape::getIndexOffset
//   indices               permutation table (initialised by the caller)
//   rng                   random generator
//
// The table update is a full swap, as in the host implementation
// (math::nd4j_swap<int>(indices[i], indices[r])): atomicExch stores indices[r]
// into indices[i] and hands back the previous indices[i], which is then written
// to indices[r]. Storing only one side would leave the same source index in
// both entries.
// NOTE(review): positions are processed concurrently by different threads, so
// entries of `indices` can still be read while another thread exchanges them -
// confirm whether this is acceptable for the op.
template <typename T>
static __global__ void fillShuffleKernel(T* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape, Nd4jLong firstDim, Nd4jLong len, int* indices, nd4j::graph::RandomGenerator* rng) {

    auto tid = blockIdx.x * blockDim.x;
    auto step = blockDim.x * gridDim.x;

    for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
        int r = rng->relativeInt(i) % i;
        output[shape::getIndexOffset(i, outputShape, len)] = input[shape::getIndexOffset(indices[r], inputShape, len)];
        if (i != r) {
            output[shape::getIndexOffset(r, outputShape, len)] = input[shape::getIndexOffset(indices[i], inputShape, len)];
            // complete swap of the two permutation entries
            indices[r] = atomicExch(&indices[i], indices[r]);
        }
    }
}
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// Shuffles `input` along its first dimension, either in place or into `output`.
//
// Three cases are distinguished:
//   * a single element / a first dimension of size 1: plain copy (nothing to do in place);
//   * vectors and vector-like shapes: the generator is copied to the device and
//     swapShuffleKernel (in place) or fillShuffleKernel (into `output`) is launched;
//   * everything else: sub-arrays along the first dimension are shuffled through
//     the NDArray API and the generator is rewound by firstDim - 1.
template <typename T>
void randomShuffle_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {

    int posOfNonUnityDim;   // output argument of shape::isLikeVector, not used afterwards
    const int firstDim = input.sizeAt(0);
    auto stream = context->getCudaStream();
    NDArray::prepareSpecialUse({&output}, {&input});

    if (input.lengthOf() == 1 || firstDim == 1) {
        // edge case: nothing to permute
        if (!isInplace)
            output.assign(input);
    }
    else if (input.isVector() || shape::isLikeVector(input.getShapeInfo(), posOfNonUnityDim)) {

        // apply Fisher-Yates shuffle on the device; the kernels work on a device copy of the generator
        nd4j::graph::RandomGenerator* deviceRng = nullptr;
        cudaMalloc(&deviceRng, sizeof(nd4j::graph::RandomGenerator));
        cudaMemcpy(deviceRng, &rng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice);
        T* source = reinterpret_cast<T*>(input.specialBuffer());

        if (isInplace) {
            swapShuffleKernel<T><<<128, 256, 1024, *stream>>>(source, input.specialShapeInfo(), firstDim, input.lengthOf(), deviceRng);
        }
        else {
            // identity permutation 0, 1, ..., firstDim - 1
            std::vector<int> order(firstDim);
            std::iota(order.begin(), order.end(), 0);

            // element 0 is copied up front
            cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice);

            PointersManager pointersManager(context, "helper::randomShuffle_");
            int* deviceOrder = reinterpret_cast<int*>(pointersManager.replicatePointer(order.data(), order.size() * sizeof(int)));
            T* target = reinterpret_cast<T*>(output.specialBuffer());
            fillShuffleKernel<T><<<128, 256, 1024, *stream>>>(source, input.specialShapeInfo(), target, output.specialShapeInfo(), firstDim, input.lengthOf(), deviceOrder, deviceRng);
            pointersManager.synchronize();
        }

        cudaFree(deviceRng);
    }
    else {

        // sub-arrays of the input through all dimensions excluding the first one
        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
        auto inputRows = input.allTensorsAlongDimension(dimensions);

        // apply Fisher-Yates shuffle
        if (isInplace) {
            PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
            for (int i = firstDim - 1; i > 0; --i) {
                int r = rng.relativeInt(i) % i;

                if (i != r)
                    inputRows->at(i)->swapUnsafe(*inputRows->at(r));
            }
        }
        else {
            // sub-arrays of the output through all dimensions excluding the first one
            auto outputRows = output.allTensorsAlongDimension(dimensions);
            std::vector<int> order(firstDim);
            std::iota(order.begin(), order.end(), 0);
            bool rowZeroWritten = false;
            PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
            for (int i = firstDim - 1; i > 0; --i) {
                int r = rng.relativeInt(i) % i;
                outputRows->at(i)->assign(inputRows->at(order[r]));
                if (r == 0)
                    rowZeroWritten = true;

                if (i != r) {
                    outputRows->at(r)->assign(inputRows->at(order[i]));
                    math::nd4j_swap<int>(order[i], order[r]);
                }
            }
            // row 0 keeps its place when it was never picked as a partner
            if (!rowZeroWritten)
                outputRows->at(0)->assign(inputRows->at(0));
            delete outputRows;
        }
        rng.rewindH(firstDim-1);
        delete inputRows;
    }
    NDArray::registerSpecialUse({&output}, {&input});

}
|
||||
|
||||
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) {
|
||||
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
|
@ -496,11 +715,6 @@ void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArr
|
|||
// Turns `output` into an identity array by delegating to NDArray::setIdentity();
// the launch context is not consulted.
void eye(nd4j::LaunchContext * context, NDArray& output) {
    NDArray& target = output;
    target.setIdentity();
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void scatterUpdate(nd4j::LaunchContext * context, NDArray& operand, NDArray& updates, const std::vector<int>* intArgs) {
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -33,27 +33,7 @@ namespace helpers {
|
|||
|
||||
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h);
|
||||
|
||||
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
|
||||
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb);
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
FORCEINLINE NDArray sigmoid(const NDArray& arr) {
|
||||
return (const_cast<NDArray&>(arr)).transform(transform::Sigmoid);
|
||||
}
|
||||
|
||||
FORCEINLINE void sigmoidInplace(const NDArray& arr) {
|
||||
(const_cast<NDArray&>(arr)).applyTransform(transform::Sigmoid);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
FORCEINLINE NDArray tanh(const NDArray& arr) {
|
||||
return (const_cast<NDArray&>(arr)).transform(transform::Tanh);
|
||||
}
|
||||
|
||||
FORCEINLINE void tanhInplace(const NDArray& arr) {
|
||||
(const_cast<NDArray&>(arr)).applyTransform(transform::Tanh);
|
||||
}
|
||||
void gruCellBP(nd4j::LaunchContext* context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, NDArray* dLdb, NDArray* dLdbc);
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,37 +27,37 @@ namespace nd4j {
|
|||
namespace ops {
|
||||
namespace helpers {
|
||||
template <typename T>
|
||||
FORCEINLINE Nd4jLong longBytes(T value);
|
||||
FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value);
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(float value) {
|
||||
FORCEINLINE _CUDA_HD Nd4jLong longBytes(float value) {
|
||||
int intie = *(int *)&value;
|
||||
return static_cast<Nd4jLong>(intie);
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(double value) {
|
||||
FORCEINLINE _CUDA_HD Nd4jLong longBytes(double value) {
|
||||
Nd4jLong longie = *(Nd4jLong *)&value;
|
||||
return longie;
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(float16 value) {
|
||||
FORCEINLINE _CUDA_HD Nd4jLong longBytes(float16 value) {
|
||||
return longBytes<float>((float) value);
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(Nd4jLong value) {
|
||||
FORCEINLINE _CUDA_HD Nd4jLong longBytes(Nd4jLong value) {
|
||||
return value;
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(bfloat16 value) {
|
||||
FORCEINLINE _CUDA_HD Nd4jLong longBytes(bfloat16 value) {
|
||||
return longBytes<float>((float) value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE Nd4jLong longBytes(T value) {
|
||||
FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value) {
|
||||
return longBytes<Nd4jLong>((Nd4jLong) value);
|
||||
}
|
||||
|
||||
|
|
|
@ -30,37 +30,30 @@ namespace helpers {
|
|||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
static FORCEINLINE NDArray activation(const NDArray& arr) {
|
||||
void rnnCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* hPrev, NDArray* ht) {
|
||||
|
||||
return (const_cast<NDArray&>(arr)).transform(transform::Tanh);
|
||||
// xt input [bS x iS]
|
||||
// Wx input-to-hidden weights, [iS x nU]
|
||||
// Wh hidden-to-hidden weights, [nU x nU]
|
||||
// b biases, [2*nU]: {0, nU} are input-to-hidden biases and {nU, 2*nU} are hidden-to-hidden biases
|
||||
// hPrev previous cell output [bS x nU], that is at previous time step t-1, in case of projection=false -> nU=nU!!!
|
||||
|
||||
const int nU = hPrev->sizeAt(1);
|
||||
|
||||
// ht is current cell output [bS x nU], that is at current time step t
|
||||
ht->assign(mmul(*xt, *Wx) + (*b)({{0, nU}}) + mmul(*hPrev, *Wh) + (*b)({{nU, 2*nU}})); // [bS x nU] + [nU] + [bS x nU] + [nU] = [bS x nU]
|
||||
ht->applyTransform(transform::Tanh);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void rnnCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* ht_1, NDArray* ht) {
|
||||
|
||||
// xt input [bS x inSize]
|
||||
// Wx input-to-hidden weights, [inSize x numUnits]
|
||||
// Wh hidden-to-hidden weights, [numUnits x numUnits]
|
||||
// b biases, [2*numUnits]: {0, numUnits} are input-to-hidden biases and {numUnits, 2*numUnits} are hidden-to-hidden biases
|
||||
// ht_1 previous cell output [bS x numUnits], that is at previous time step t-1, in case of projection=false -> numUnits=numUnits!!!
|
||||
|
||||
const int numUnits = ht_1->sizeAt(1);
|
||||
|
||||
// ht is current cell output [bS x numUnits], that is at current time step t
|
||||
ht->assign(activation(mmul(*xt, *Wx) + (*b)({{0, numUnits}}) + mmul(*ht_1, *Wh) + (*b)({{numUnits, 2*numUnits}}))); // [bS x numUnits] + [numUnits] + [bS x numUnits] + [numUnits] = [bS x numUnits]
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal) {
|
||||
|
||||
// x input [time x bS x inSize]
|
||||
// Wx input-to-hidden weights, [inSize x numUnits]
|
||||
// Wh hidden-to-hidden weights, [numUnits x numUnits]
|
||||
// b biases for, [2*numUnits]
|
||||
// x input [time x bS x iS]
|
||||
// Wx input-to-hidden weights, [iS x nU]
|
||||
// Wh hidden-to-hidden weights, [nU x nU]
|
||||
// b biases for, [2*nU]
|
||||
|
||||
// h0 initial cell output (at time step = 0) [bS x numUnits]
|
||||
// h0 initial cell output (at time step = 0) [bS x nU]
|
||||
// maxTimeStep vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
|
||||
|
||||
const int time = x->sizeAt(0);
|
||||
|
@ -82,16 +75,16 @@ void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
|
|||
|
||||
auto xt = (*x)({t,t+1, e,e+1, 0,0}, true);
|
||||
auto ht = (*h)({t,t+1, e,e+1, 0,0}, true);
|
||||
auto ht_1 = (*hFinal)({e,e+1, 0,0}, true); // previous state
|
||||
auto hPrev = (*hFinal)({e,e+1, 0,0}, true); // previous state
|
||||
|
||||
if(t >= maxStep) {
|
||||
ht = 0.;
|
||||
if(maxStep != 0)
|
||||
ht_1.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0}));
|
||||
hPrev.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0}));
|
||||
}
|
||||
else {
|
||||
helpers::rnnCell(context, &xt, Wx, Wh, b, &ht_1, &ht);
|
||||
ht_1.assign(ht);
|
||||
helpers::rnnCell(context, &xt, Wx, Wh, b, &hPrev, &ht);
|
||||
hPrev.assign(ht);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author sgazeos@gmail.com
|
||||
// @brief helpers common functions for segment_* ops (segment_max, segment_min, etc.)
|
||||
// @brief helpers common functions for unsorted_segment_* ops (unsorted_segment_max, etc.)
|
||||
//
|
||||
#ifndef __SEGMENT_COMMON_HELPERS__
|
||||
#define __SEGMENT_COMMON_HELPERS__
|
||||
#include <op_boilerplate.h>
|
||||
#include <NDArray.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens);
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <helpers/helper_random.h>
|
||||
#include <graph/RandomGenerator.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
@ -32,7 +33,7 @@ namespace helpers {
|
|||
|
||||
void trace(nd4j::LaunchContext * context, const NDArray& input, NDArray& output);
|
||||
|
||||
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace);
|
||||
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace);
|
||||
|
||||
// auxiliary function which serves for recursion purpose and is used in pad operation
|
||||
// void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue);
|
||||
|
|
|
@ -1126,15 +1126,7 @@ inline __device__ bool nd4j_atomicAdd<bool>(bool* address, bool val) {
|
|||
|
||||
template <>
|
||||
inline __device__ double nd4j_atomicSub<double>(double* address, double val) {
|
||||
unsigned long long int* address_as_ull =
|
||||
(unsigned long long int *) address;
|
||||
unsigned long long int old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val -
|
||||
__longlong_as_double(assumed)));
|
||||
} while (assumed != old);
|
||||
return __longlong_as_double(old);
|
||||
return nd4j_atomicAdd<double>(address, -val);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -1152,15 +1144,7 @@ inline __device__ double nd4j_atomicMul<double>(double* address, double val) {
|
|||
|
||||
template <>
|
||||
inline __device__ double nd4j_atomicDiv<double>(double* address, double val) {
|
||||
unsigned long long int* address_as_ull =
|
||||
(unsigned long long int*) address;
|
||||
unsigned long long int old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val /
|
||||
__longlong_as_double(assumed)));
|
||||
} while (assumed != old);
|
||||
return __longlong_as_double(old);
|
||||
return nd4j_atomicMul<double>(address, 1./val);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -1179,14 +1163,16 @@ inline __device__ int32_t nd4j_atomicAdd<int32_t>(int32_t* address, int32_t val)
|
|||
|
||||
template <>
|
||||
inline __device__ float nd4j_atomicSub<float>(float* address, float val) {
|
||||
int* address_as_ull = (int*) address;
|
||||
int old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed, __float_as_int(val -
|
||||
__float_as_int(assumed)));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
return nd4j_atomicAdd<float>(address, -val);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ float16 nd4j_atomicSub<float16>(float16* address, float16 val) {
|
||||
return nd4j_atomicAdd<float16>(address, -val);
|
||||
}
|
||||
template <>
|
||||
inline __device__ bfloat16 nd4j_atomicSub<bfloat16>(bfloat16* address, bfloat16 val) {
|
||||
return nd4j_atomicAdd<bfloat16>(address, -val);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -1415,6 +1401,30 @@ inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val)
|
|||
|
||||
template <>
|
||||
inline __device__ float nd4j_atomicDiv<float>(float* address, float val) {
|
||||
int* address_as_ull =
|
||||
(int*)address;
|
||||
int old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed, __float_as_int(__int_as_float(assumed) / val ));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ float16 nd4j_atomicDiv<float16>(float16* address, float16 val) {
|
||||
int* address_as_ull =
|
||||
(int*)address;
|
||||
int old = *address_as_ull, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
|
||||
__float_as_int(assumed)));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
template <>
|
||||
inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 val) {
|
||||
int* address_as_ull =
|
||||
(int*)address;
|
||||
int old = *address_as_ull, assumed;
|
||||
|
|
|
@ -76,6 +76,9 @@
|
|||
(nd4j::DataType::FLOAT32, float), \
|
||||
(nd4j::DataType::DOUBLE, double)
|
||||
|
||||
#define FLOAT_NATIVE \
|
||||
(nd4j::DataType::FLOAT32, float), \
|
||||
(nd4j::DataType::DOUBLE, double)
|
||||
|
||||
#define FLOAT_TYPES_0 \
|
||||
(nd4j::DataType::HALF, float16)
|
||||
|
|
|
@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
NDArray* result = results->at(0);
|
||||
result->printIndexedBuffer("OOOOUUUUTTT");
|
||||
// result->printIndexedBuffer("OOOOUUUUTTT");
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
@ -1881,9 +1881,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
|
|||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
|
||||
|
||||
NDArray boxes = NDArrayFactory::create<float>('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f,
|
||||
NDArray boxes = NDArrayFactory::create<double>('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f,
|
||||
0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101});
|
||||
NDArray scales = NDArrayFactory::create<float>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5
|
||||
NDArray scales = NDArrayFactory::create<double>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5
|
||||
NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5});
|
||||
|
||||
nd4j::ops::non_max_suppression op;
|
||||
|
@ -1892,7 +1892,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
NDArray* result = results->at(0);
|
||||
result->printBuffer("NonMaxSuppression OUtput2");
|
||||
// result->printBuffer("NonMaxSuppression OUtput2");
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
|
@ -1970,6 +1970,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
|
|||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
|
||||
|
||||
|
|
|
@ -421,3 +421,200 @@ ASSERT_TRUE(result->at(0)->e<bool>(0));
|
|||
//ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustHue_1) {
|
||||
|
||||
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_hue op;
|
||||
auto results = op.execute({&input}, {0.5}, {2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
// result->printIndexedBuffer();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustHue_2) {
|
||||
|
||||
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {2,2,3}, {4,100,0, 146,220,5, 97,123.8,230, 255,2,164.8}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_hue op;
|
||||
auto results = op.execute({&input}, {0.9}, {2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustHue_3) {
|
||||
|
||||
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_hue op;
|
||||
auto results = op.execute({&input}, {-0.9}, {2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustHue_4) {
|
||||
|
||||
NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_hue op;
|
||||
auto results = op.execute({&input}, {0.5}, {1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustHue_5) {
|
||||
|
||||
NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_hue op;
|
||||
auto results = op.execute({&input}, {0.5}, {0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
|
||||
|
||||
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_saturation op;
|
||||
auto results = op.execute({&input}, {0.5}, {2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
// result->printIndexedBuffer();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
|
||||
|
||||
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_saturation op;
|
||||
auto results = op.execute({&input}, {10}, {2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
// result->printIndexedBuffer();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustSaturation_3) {
|
||||
|
||||
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_saturation op;
|
||||
auto results = op.execute({&input}, {-10}, {2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustSaturation_4) {
|
||||
|
||||
NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_saturation op;
|
||||
auto results = op.execute({&input}, {0.5}, {1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
// result->printIndexedBuffer();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
|
||||
|
||||
NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::adjust_saturation op;
|
||||
auto results = op.execute({&input}, {0.5}, {0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
|
|
@ -1479,6 +1479,27 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {
|
|||
|
||||
delete results;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests5, random_shuffle_test04) {
|
||||
auto input = NDArrayFactory::create<double>('c', {4});
|
||||
input.linspace(1);
|
||||
|
||||
nd4j::ops::random_shuffle op;
|
||||
//NDArray* output;
|
||||
auto results = op.execute({&input}, {}, {}, {}, true, nd4j::DataType::DOUBLE);
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
auto output = &input; //results->at(0);
|
||||
bool haveZeros = false;
|
||||
for(int i = 0; i < output->lengthOf(); ++i)
|
||||
if(output->e<float>(i) == (float)0.)
|
||||
haveZeros = true;
|
||||
|
||||
ASSERT_TRUE(input.isSameShape(output));
|
||||
//ASSERT_TRUE(!input.equalsTo(output));
|
||||
ASSERT_TRUE(!haveZeros);
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
|
||||
|
@ -1486,17 +1507,17 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
|
|||
input.linspace(1);
|
||||
|
||||
nd4j::ops::random_shuffle op;
|
||||
//NDArray* output;
|
||||
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
auto output = results->at(0);
|
||||
|
||||
bool haveZeros = false;
|
||||
for(int i = 0; i < output->lengthOf(); ++i)
|
||||
if(output->e<float>(i) == (float)0.)
|
||||
haveZeros = true;
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(input.isSameShape(output));
|
||||
ASSERT_TRUE(!input.equalsTo(output));
|
||||
//ASSERT_TRUE(!input.equalsTo(output));
|
||||
ASSERT_TRUE(!haveZeros);
|
||||
|
||||
delete results;
|
||||
|
|
|
@ -1601,8 +1601,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
// z->printIndexedBuffer("Output ");
|
||||
// exp.printIndexedBuffer("Expected ");
|
||||
z->printIndexedBuffer("Output ");
|
||||
exp.printIndexedBuffer("Expected ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
@ -1610,6 +1610,75 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
|
||||
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||
2., 4., 60., 8., 10.,
|
||||
0., 1., 2., 3., 4.,
|
||||
0., 0., 2., 4., 6.,
|
||||
0., 0., 0., 1., 2.,
|
||||
0., 0., 0., 0., 4.
|
||||
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||
0.5, -2.0, -13.0, 54.0, -6.75,
|
||||
0.0, 1.0, -1.0, 1.0, 0.0,
|
||||
0, 0, 0.5, -2.0, 0.25,
|
||||
0, 0, 0, 1.0, -0.5,
|
||||
0, 0, 0, 0, 0.25
|
||||
|
||||
});
|
||||
|
||||
nd4j::ops::matrix_inverse op;
|
||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
z->printIndexedBuffer("Output ");
|
||||
exp.printIndexedBuffer("Expected ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
|
||||
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||
1., 0., 0., 0., 0.,
|
||||
2., 1., 0., 0., 0.,
|
||||
30., 2., 1., 0., 0.,
|
||||
4., 3., 2., 1., 0.,
|
||||
5., 4., 3., 2., 1.
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||
1.0, 0.0, 0.0, 0.0, 0.,
|
||||
-2.0, 1.0, 0., 0., 0.,
|
||||
-26.0, -2.0, 1, 0, 0.,
|
||||
54.0, 1.0, -2.0, 1, 0.,
|
||||
-27.0, 0.0, 1.0, -2.0, 1.
|
||||
});
|
||||
|
||||
nd4j::ops::matrix_inverse op;
|
||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
z->printIndexedBuffer("Output ");
|
||||
exp.printIndexedBuffer("Expected ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/*
|
||||
|
@ -1658,6 +1727,39 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
|
|||
delete result;
|
||||
}
|
||||
*/
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
|
||||
|
||||
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||
4., 0., 0., 0., 0.,
|
||||
4., 2., 0., 0., 0.,
|
||||
30., 2., 1., 0., 0.,
|
||||
8., 6., 4., 2., 0.,
|
||||
15., 12., 9., 6., 3.,
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||
0.25, 0.0, 0.0, 0.0, 0.0,
|
||||
-0.50, 0.5, 0.0, 0.0, 0.0,
|
||||
-6.50, -1.0, 1.0, 0.0, 0.0,
|
||||
13.50, 0.5, -2.0, 0.5, 0.0,
|
||||
-6.75, 0.0, 1.0, -1.0, 0.33333333
|
||||
});
|
||||
|
||||
nd4j::ops::matrix_inverse op;
|
||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
z->printIndexedBuffer("Output ");
|
||||
exp.printIndexedBuffer("Expected ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
||||
|
||||
|
@ -1695,7 +1797,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {5, 5}, {
|
||||
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||
1., 2., 30., 4., 5.,
|
||||
0., 1., 2., 3., 4.,
|
||||
0., 0., 1., 2., 3.,
|
||||
|
@ -1703,7 +1805,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
|||
0., 0., 0., 0., 1.
|
||||
});
|
||||
|
||||
auto exp = NDArrayFactory::create<double>('c', {5, 5}, {
|
||||
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||
1.0, -2.0, -26.0, 54.0, -27.0,
|
||||
0.0, 1.0, -2.0, 1.0, 0.0,
|
||||
0.0, 0.0, 1.0, -2.0, 1.0,
|
||||
|
@ -1712,13 +1814,13 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
|||
});
|
||||
|
||||
nd4j::ops::matrix_inverse op;
|
||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
//z->printIndexedBuffer("Output ");
|
||||
//exp.printIndexedBuffer("Expected ");
|
||||
z->printIndexedBuffer("Output ");
|
||||
exp.printIndexedBuffer("Expected ");
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
|
|
@ -763,15 +763,15 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) {
|
|||
|
||||
|
||||
TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) {
|
||||
auto input = NDArrayFactory::create<double>('c', {4, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f});
|
||||
auto exp = NDArrayFactory::create<double>('c', {4, 4, 16}, {1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f });
|
||||
auto input = NDArrayFactory::create<int>('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
auto exp = NDArrayFactory::create<bool>('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 });
|
||||
|
||||
nd4j::ops::sequence_mask op;
|
||||
auto result = op.execute({&input}, {}, {});
|
||||
|
@ -788,19 +788,19 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) {
|
|||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) {
|
||||
auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, {10., 20., 30., 4., 0., 6., 7., 8.});
|
||||
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2, 30}, { 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
|
||||
auto input = NDArrayFactory::create<int>('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8});
|
||||
auto exp = NDArrayFactory::create<bool>('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
|
||||
nd4j::ops::sequence_mask op;
|
||||
auto result = op.execute({&input}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
// z->printIndexedBuffer("Output");
|
||||
// z->printBuffer("Output");
|
||||
// z->printShapeInfo("Shape");
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
|
|
@ -2770,9 +2770,8 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
|
|||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
/*
|
||||
//2019/02/23 AB - GRU backprop tests disabled pending update of GRU backprop op after rewriting forward pass
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
|
||||
|
||||
const int bS = 2;
|
||||
|
@ -2780,160 +2779,58 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
|
|||
const int nU = 4;
|
||||
|
||||
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
|
||||
NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray hi('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray W('c', {iS+nU, 2*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray Wc('c', {iS+nU, nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray b('c', {2*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray bc('c', {nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray dLdr('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray dLdu('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray dLdc('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
|
||||
x.linspace(0.5, 0.5);
|
||||
h0 = 1.;
|
||||
Wx = 0.003;
|
||||
Wh = 0.006;
|
||||
x.linspace(-5, 0.5);
|
||||
hi = 1.;
|
||||
W = 0.003;
|
||||
Wc = 0.006;
|
||||
b = 0.5;
|
||||
bc = 0.35;
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
||||
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &hi, &W, &Wc, &b, &bc}, {}, {});
|
||||
nd4j::ops::gruCell op;
|
||||
auto results = op.execute(argsHolderFF);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto u = results->at(1); // [bS, nU]
|
||||
auto c = results->at(2); // [bS, nU]
|
||||
auto h = results->at(3); // [bS, nU]
|
||||
|
||||
dLdh = 1.; // SUM loss
|
||||
|
||||
NDArray Wch = Wc({iS,iS+nU, 0,0}); // [nU, nU]
|
||||
NDArray dhdc = 1. - *u;
|
||||
NDArray dhdu = hi - *c;
|
||||
NDArray dcdZc = 1. - *c * *c;
|
||||
dLdc.assign(dLdh * dhdc);
|
||||
dLdu.assign(dLdh * dhdu);
|
||||
dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose()));
|
||||
|
||||
delete results;
|
||||
|
||||
|
||||
const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {});
|
||||
|
||||
nd4j::ops::gruCell opFF;
|
||||
nd4j::ops::gruCell_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test2) {
|
||||
|
||||
const int bS = 2;
|
||||
const int iS = 3;
|
||||
const int nU = 4;
|
||||
|
||||
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
|
||||
NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
|
||||
x.linspace(0.5, 0.5);
|
||||
h0 = 1.;
|
||||
Wx = 0.003;
|
||||
Wh = 0.006;
|
||||
b = 0.;
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
||||
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
||||
|
||||
nd4j::ops::gruCell opFF;
|
||||
nd4j::ops::gruCell_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test3) {
|
||||
|
||||
const int bS = 2;
|
||||
const int iS = 3;
|
||||
const int nU = 4;
|
||||
|
||||
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
|
||||
NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE);
|
||||
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||
// NDArray<double> dLdWx0('c', {iS, 3*nU});
|
||||
// NDArray<double> dLdWh0('c', {nU, 3*nU});
|
||||
// NDArray<double> dLdb0 ('c', {3*nU});
|
||||
|
||||
x = 1.;
|
||||
h0 = 0.0;
|
||||
Wx = 0.0;
|
||||
Wh = 0.0;
|
||||
b = 0.5;
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
||||
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
||||
|
||||
nd4j::ops::gruCell opFF;
|
||||
nd4j::ops::gruCell_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// TEST_F(DeclarableOpsTests9, gru_bp_test1) {
|
||||
|
||||
// const int time = 5;
|
||||
// const int bS = 2;
|
||||
// const int iS = 3;
|
||||
// const int nU = 4;
|
||||
|
||||
// NDArray<double> x ('c', {time, bS, iS});
|
||||
// NDArray<double> h0 ('c', {bS, nU});
|
||||
// NDArray<double> Wx ('c', {iS, 3*nU});
|
||||
// NDArray<double> Wh ('c', {nU, 3*nU});
|
||||
// NDArray<double> b ('c', {3*nU});
|
||||
// NDArray<double> dLdh ('c', {time, bS, nU});
|
||||
|
||||
// x.linspace(0.5, 0.5);
|
||||
// h0 = 1.;
|
||||
// Wx = 0.003;
|
||||
// Wh = 0.006;
|
||||
// b = 0.5;
|
||||
|
||||
// const OpArgsHolder<double> argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
||||
// const OpArgsHolder<double> argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
||||
|
||||
// nd4j::ops::gru<double> opFF;
|
||||
// nd4j::ops::gru_bp<double> opBP;
|
||||
|
||||
// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
// ASSERT_TRUE(isGradCorrect);
|
||||
// }
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test3_1) {
|
||||
|
||||
const int bS = 2;
|
||||
const int iS = 3;
|
||||
const int nU = 4;
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {bS, iS});
|
||||
auto h0 = NDArrayFactory::create<double>('c', {bS, nU});
|
||||
auto Wx = NDArrayFactory::create<double>('c', {iS, 3*nU});
|
||||
auto Wh = NDArrayFactory::create<double>('c', {nU, 3*nU});
|
||||
auto b = NDArrayFactory::create<double>('c', {3*nU});
|
||||
auto dLdh = NDArrayFactory::create<double>('c', {bS, nU});
|
||||
// NDArray<double> dLdWx0('c', {iS, 3*nU});
|
||||
// NDArray<double> dLdWh0('c', {nU, 3*nU});
|
||||
// NDArray<double> dLdb0 ('c', {3*nU});
|
||||
|
||||
x = 1.;
|
||||
h0 = 0.0;
|
||||
Wx = 0.0;
|
||||
Wh = 0.0;
|
||||
b = 0.5;
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
||||
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
||||
|
||||
nd4j::ops::gruCell opFF;
|
||||
nd4j::ops::gruCell_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1 , 1, 1}, {0., 1.}, nd4j::GradCheck::LossFunc::SUM, true);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
*/
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
|
||||
|
||||
|
|
|
@ -719,6 +719,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
|
|||
}
|
||||
|
||||
TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
|
||||
|
||||
auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4});
|
||||
NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64);
|
||||
auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1});
|
||||
|
@ -1588,36 +1589,79 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ParityOpsTests, scatter_update_1) {
    // Runs scatter_update directly on a 2x2 INT32 array and checks that the
    // input array `x` itself ends up equal to `exp`.
    // Written in the same shape as scatter_update_2..4: the op is executed with
    // explicit inputs instead of a hand-built VariableSpace/Context.

    NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
    NDArray updates('c', {2,2}, {10,20,30,40}, nd4j::DataType::INT32);

    NDArray exp('c', {2,2}, {30,40,10,20}, nd4j::DataType::INT32);

    nd4j::ops::scatter_update op;
    // NOTE(review): the integer arguments are passed verbatim; their layout is
    // defined by the scatter_update op and is not visible from this test.
    auto results = op.execute({&x, &updates}, {}, {6, 1,1, 2,1,0});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    // The assertions inspect `x`, not an entry of `results`.
    ASSERT_TRUE(exp.isSameShape(x));
    ASSERT_TRUE(exp.equalsTo(x));

    delete results;
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ParityOpsTests, scatter_update_2) {
    // Executes scatter_update on a 2x2 INT32 array and verifies that the input
    // array holds {20,10,40,30} afterwards.

    NDArray expected('c', {2,2}, {20,10,40,30}, nd4j::DataType::INT32);

    NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
    NDArray updates('c', {2,2}, {10,20,30,40}, nd4j::DataType::INT32);

    // NOTE(review): integer arguments are forwarded as-is; their meaning is
    // defined by the scatter_update op, not by this test.
    nd4j::ops::scatter_update scatterOp;
    auto outcome = scatterOp.execute({&x, &updates}, {}, {6, 1,0, 2,1,0});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    // The checks are made against `x` itself rather than an output of `outcome`.
    ASSERT_TRUE(expected.isSameShape(x));
    ASSERT_TRUE(expected.equalsTo(x));

    delete outcome;
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ParityOpsTests, scatter_update_3) {
    // Executes scatter_update on a 2x2x2 INT32 array and verifies that the
    // input array holds {50,60,70,80,10,20,30,40} afterwards.

    NDArray expected('c', {2,2,2}, {50,60,70,80,10,20,30,40}, nd4j::DataType::INT32);

    NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, nd4j::DataType::INT32);
    NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, nd4j::DataType::INT32);

    // NOTE(review): integer arguments are forwarded as-is; their meaning is
    // defined by the scatter_update op, not by this test.
    nd4j::ops::scatter_update scatterOp;
    auto outcome = scatterOp.execute({&x, &updates}, {}, {6, 2,1,2, 2,1,0});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    // The checks are made against `x` itself rather than an output of `outcome`.
    ASSERT_TRUE(expected.isSameShape(x));
    ASSERT_TRUE(expected.equalsTo(x));

    delete outcome;
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(ParityOpsTests, scatter_update_4) {
    // Executes scatter_update on a 2x2x2 INT32 array and verifies that the
    // input array holds {20,2,3,10,60,6,7,50} afterwards.

    NDArray expected('c', {2,2,2}, {20,2,3,10,60,6,7,50}, nd4j::DataType::INT32);

    NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, nd4j::DataType::INT32);
    NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, nd4j::DataType::INT32);

    // NOTE(review): integer arguments are forwarded as-is; their meaning is
    // defined by the scatter_update op, not by this test.
    nd4j::ops::scatter_update scatterOp;
    auto outcome = scatterOp.execute({&x, &updates}, {}, {6, 1,0, 2,3,0});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    // The checks are made against `x` itself rather than an output of `outcome`.
    ASSERT_TRUE(expected.isSameShape(x));
    ASSERT_TRUE(expected.equalsTo(x));

    delete outcome;
}
|
||||
|
|
|
@ -278,8 +278,8 @@ TEST_F(RNGTests, Test_Gaussian_22) {
|
|||
auto x0 = NDArrayFactory::create<float>('c', {10000, 1000});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {10000, 1000});
|
||||
|
||||
RandomLauncher::fillGaussian(_rngA, &x0, 0.0f, 1.0f);
|
||||
RandomLauncher::fillGaussian(_rngB, &x1, 0.0f, 1.0f);
|
||||
RandomLauncher::fillGaussian(nd4j::LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f);
|
||||
RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f);
|
||||
|
||||
//x0.printIndexedBuffer("x0");
|
||||
//x1.printIndexedBuffer("x1");
|
||||
|
@ -306,7 +306,7 @@ TEST_F(RNGTests, Test_Gaussian_22) {
|
|||
TEST_F(RNGTests, Test_Gaussian_3) {
|
||||
auto x0 = NDArrayFactory::create<double>('c', {10000000});
|
||||
|
||||
RandomLauncher::fillGaussian(_rngA, &x0, 0.0, 1.0);
|
||||
RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, 1.0);
|
||||
|
||||
auto mean = x0.meanNumber().e<double>(0);
|
||||
auto stdev = x0.varianceNumber(nd4j::variance::SummaryStatsStandardDeviation, false).e<double>(0);
|
||||
|
@ -319,8 +319,8 @@ TEST_F(RNGTests, Test_LogNormal_1) {
|
|||
auto x0 = NDArrayFactory::create<float>('c', {10, 10});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
|
||||
|
||||
RandomLauncher::fillLogNormal(_rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillLogNormal(_rngB, &x1, 1.0f, 2.0f);
|
||||
RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||
|
||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||
|
||||
|
@ -333,8 +333,8 @@ TEST_F(RNGTests, Test_Truncated_1) {
|
|||
auto x0 = NDArrayFactory::create<float>('c', {10, 10});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
|
||||
|
||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||
|
||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||
|
||||
|
@ -357,8 +357,8 @@ TEST_F(RNGTests, Test_Truncated_2) {
|
|||
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||
|
||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||
|
||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||
|
||||
|
@ -383,8 +383,8 @@ TEST_F(RNGTests, Test_Truncated_21) {
|
|||
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||
|
||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||
|
||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||
|
||||
|
@ -430,8 +430,8 @@ TEST_F(RNGTests, Test_Truncated_22) {
|
|||
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||
|
||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 2.0f, 4.0f);
|
||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 2.0f, 4.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 2.0f, 4.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 2.0f, 4.0f);
|
||||
|
||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||
|
||||
|
@ -477,8 +477,8 @@ TEST_F(RNGTests, Test_Truncated_23) {
|
|||
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||
|
||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 0.0f, 1.0f);
|
||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 0.0f, 1.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f);
|
||||
|
||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||
|
||||
|
@ -524,8 +524,8 @@ TEST_F(RNGTests, Test_Truncated_3) {
|
|||
auto x0 = NDArrayFactory::create<float>('c', {10000, 1000});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {10000, 1000});
|
||||
|
||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||
|
||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||
|
||||
|
@ -964,7 +964,7 @@ TEST_F(RNGTests, Test_Reproducibility_2) {
|
|||
TEST_F(RNGTests, Test_Uniform_4) {
|
||||
auto x1 = NDArrayFactory::create<double>('c', {1000000});
|
||||
|
||||
RandomLauncher::fillUniform(_rngB, &x1, 1.0, 2.0);
|
||||
RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0, 2.0);
|
||||
|
||||
/* Check up distribution */
|
||||
auto mean = x1.reduceNumber(reduce::Mean);
|
||||
|
|
|
@ -69,6 +69,24 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_1) {
|
|||
ASSERT_EQ(ev, v);
|
||||
}
|
||||
|
||||
TEST_F(SortCudaTests, test_linear_sort_by_val_2) {
    // Sorts six key/value pairs through NativeOps::sortByValue and checks both
    // the reordered values and the keys that travelled with them.
    auto k = NDArrayFactory::create<int>('c', {6}, {0, 1, 2, 3, 4, 5});
    // Double literals, so the inputs are exactly the values listed in `ev`.
    auto v = NDArrayFactory::create<double>('c', {6}, {0.9, 0.75, 0.6, 0.95, 0.5, 0.3});

    // Expected result: values in descending order, keys permuted alongside.
    auto ek = NDArrayFactory::create<int>('c', {6}, {3, 0, 1, 2, 4, 5});
    auto ev = NDArrayFactory::create<double>('c', {6}, {0.95, 0.9, 0.75, 0.6, 0.5, 0.3});

    // Second slot carries the CUDA stream of the default launch context.
    Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};

    NativeOps nativeOps;
    nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);

    // NOTE(review): presumably flags the device-side buffers as the current copy
    // after the raw NativeOps call — confirm against NDArray::tickWriteDevice.
    k.tickWriteDevice();
    v.tickWriteDevice();

    ASSERT_EQ(ek, k);
    ASSERT_EQ(ev, v);
}
|
||||
|
||||
TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
|
||||
auto k = NDArrayFactory::create<Nd4jLong>('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8});
|
||||
auto v = NDArrayFactory::create<double>('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5});
|
||||
|
|
Loading…
Reference in New Issue