[WIP] More of CUDA operations (#69)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * - gruCell_bp further Signed-off-by: Yurii <yurii@skymind.io> * - further work on gruCell_bp Signed-off-by: Yurii <yurii@skymind.io> * Inverse matrix cublas implementation. Partial working revision. * Separation of segment ops helpers. Max separation. * Separated segment_min ops. * Separation of segment_mean/sum/prod/sqrtN ops heleprs. * Fixed diagonal processing with LUP decomposition. * Modified inversion approach using current state of LU decomposition. * Implementation of matrix_inverse op with cuda kernels. Working revision. * Implemented sequence_mask cuda helper. Eliminated waste printf with matrix_inverse implementation. Added proper tests. * - further work on gruCell_bp (ff/cuda) Signed-off-by: Yurii <yurii@skymind.io> * comment one test for gruCell_bp Signed-off-by: Yurii <yurii@skymind.io> * - provide cuda static_rnn Signed-off-by: Yurii <yurii@skymind.io> * Refactored random_shuffle op to use new random generator. * Refactored random_shuffle op helper. * Fixed debug tests with random ops tests. * Implement random_shuffle op cuda kernel helper and tests. * - provide cuda scatter_update Signed-off-by: Yurii <yurii@skymind.io> * Implementation of random_shuffle for linear case with cuda kernels and tests. * Implemented random_shuffle with cuda kernels. Final revision. * - finally gruCell_bp is completed Signed-off-by: Yurii <yurii@skymind.io> * Dropout op cuda helper implementation. * Implemented dropout_bp cuda helper. * Implemented alpha_dropout_bp with cuda kernel helpers. * Refactored helper. * Implementation of suppresion helper with cuda kernels. * - provide cpu code fot hsvToRgb, rgbToHsv, adjustHue Signed-off-by: Yurii <yurii@skymind.io> * Using sort by value method. * Implementation of image.non_max_suppression op cuda-based helper. * - correcting and testing adjust_hue, adjust_saturation cpu/cuda code Signed-off-by: Yurii <yurii@skymind.io> * Added cuda device prefixes to declarations. * Implementation of hashcode op with cuda helper. Initital revision. * rnn cu impl removed Signed-off-by: raver119 <raver119@gmail.com>master
parent
06e4f5f96e
commit
763a225c6a
|
@ -155,20 +155,20 @@ namespace nd4j {
|
||||||
#ifndef __JAVACPP_HACK__
|
#ifndef __JAVACPP_HACK__
|
||||||
NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& descriptor, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const Nd4jLong offset = 0);
|
NDArray(std::shared_ptr<DataBuffer> buffer, const ShapeDescriptor& descriptor, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const Nd4jLong offset = 0);
|
||||||
|
|
||||||
NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
NDArray(std::shared_ptr<DataBuffer> buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* do not allocate memory, memory for array is passed from outside
|
* do not allocate memory, memory for array is passed from outside
|
||||||
*/
|
*/
|
||||||
NDArray(void *buffer, Nd4jLong* shapeInfo, nd4j::LaunchContext * context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
|
NDArray(void *buffer, Nd4jLong* shapeInfo, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* do not allocate memory, memory for array is passed from outside
|
* do not allocate memory, memory for array is passed from outside
|
||||||
* we suppose the content of both (device and host) buffers is identical
|
* we suppose the content of both (device and host) buffers is identical
|
||||||
*/
|
*/
|
||||||
NDArray(void *buffer, void *bufferD, Nd4jLong* shapeInfo, nd4j::LaunchContext * context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false, const bool isBuffDAlloc = false);
|
NDArray(void *buffer, void *bufferD, Nd4jLong* shapeInfo, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false, const bool isBuffDAlloc = false);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* copy constructor
|
* copy constructor
|
||||||
|
@ -189,28 +189,28 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
||||||
*/
|
*/
|
||||||
NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
||||||
* set dtype as array type
|
* set dtype as array type
|
||||||
*/
|
*/
|
||||||
NDArray(Nd4jLong* shapeInfo, const nd4j::DataType dtype, const bool copyStrides = false, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
NDArray(Nd4jLong* shapeInfo, const nd4j::DataType dtype, const bool copyStrides = false, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this constructor creates new array using shape information contained in vector argument
|
* this constructor creates new array using shape information contained in vector argument
|
||||||
*/
|
*/
|
||||||
NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype
|
* This constructor creates new array with elements copied from data and using shape information stored in shape, elements from data will be casted to dtype
|
||||||
*/
|
*/
|
||||||
NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype = DOUBLE, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext());
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this constructor creates new array using given buffer (without memory allocating) and shape information stored in shape
|
* this constructor creates new array using given buffer (without memory allocation) and shape information stored in shape
|
||||||
*/
|
*/
|
||||||
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext());
|
NDArray(void *buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isBuffAlloc = false);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this constructor creates new NDArray with shape matching "other" array,
|
* this constructor creates new NDArray with shape matching "other" array,
|
||||||
|
@ -221,7 +221,7 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar
|
* this constructor creates scalar(and set its value = 0) or empty array depending on bool argument isScalar
|
||||||
*/
|
*/
|
||||||
NDArray(nd4j::DataType dtype, nd4j::LaunchContext * context = nd4j::LaunchContext ::defaultContext(), const bool isScalar = true);
|
NDArray(nd4j::DataType dtype, nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(), const bool isScalar = true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method blocks until asynchronous operation finishes
|
* This method blocks until asynchronous operation finishes
|
||||||
|
|
|
@ -132,9 +132,8 @@ NDArray::NDArray(const NDArray *other, const bool copyStrides, nd4j::LaunchConte
|
||||||
_buffer = std::make_shared<DataBuffer>(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace());
|
_buffer = std::make_shared<DataBuffer>(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context, const bool isBuffAlloc) {
|
||||||
|
|
||||||
if (shape.empty())
|
if (shape.empty())
|
||||||
throw std::runtime_error("NDArray constructor: input shape is empty !");
|
throw std::runtime_error("NDArray constructor: input shape is empty !");
|
||||||
|
@ -148,7 +147,7 @@ NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &sh
|
||||||
|
|
||||||
setShapeInfo(ShapeDescriptor(dtype, order, shape));
|
setShapeInfo(ShapeDescriptor(dtype, order, shape));
|
||||||
|
|
||||||
_buffer = std::make_shared<DataBuffer>(buffer, lengthOf() * sizeOfT(), dataType(), true, getContext()->getWorkspace());
|
_buffer = std::make_shared<DataBuffer>(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace());
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -1498,16 +1498,6 @@ void NativeOps::specialConcat(
|
||||||
* This method saves
|
* This method saves
|
||||||
*/
|
*/
|
||||||
nd4j::TadPack* NativeOps::tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensionLength) {
|
nd4j::TadPack* NativeOps::tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimensionLength) {
|
||||||
/*shape::TAD tad;
|
|
||||||
tad.init(dXShapeInfo, dimension, dimensionLength);
|
|
||||||
//tad->setOutputBuffer(target);
|
|
||||||
tad.createTadOnlyShapeInfo();
|
|
||||||
tad.createOffsets();
|
|
||||||
|
|
||||||
|
|
||||||
std::memcpy(reinterpret_cast<void *>(target), tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo));
|
|
||||||
std::memcpy(reinterpret_cast<void *>(offsets), tad.tadOffsets, tad.numTads * sizeof(Nd4jLong));
|
|
||||||
*/
|
|
||||||
auto pack = new TadPack();
|
auto pack = new TadPack();
|
||||||
*pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength);
|
*pack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(dXShapeInfo, dimension, dimensionLength);
|
||||||
return pack;
|
return pack;
|
||||||
|
|
|
@ -45,9 +45,9 @@ namespace nd4j {
|
||||||
|
|
||||||
static ConstantTadHelper* getInstance();
|
static ConstantTadHelper* getInstance();
|
||||||
|
|
||||||
TadPack& tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
TadPack& tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
TadPack& tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
TadPack& tadForDimensions(const Nd4jLong *originalShape, int dimensions, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
TadPack& tadForDimensions(ShapeDescriptor &descriptor, std::vector<int> &dimensions, const bool keepUnitiesInShape = false);
|
||||||
TadPack& tadForDimensions(TadDescriptor &descriptor);
|
TadPack& tadForDimensions(TadDescriptor &descriptor);
|
||||||
};
|
};
|
||||||
|
|
|
@ -38,15 +38,15 @@ namespace nd4j {
|
||||||
return _INSTANCE;
|
return _INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||||
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||||
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||||
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
||||||
return tadForDimensions(tadDescriptor);
|
return tadForDimensions(tadDescriptor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,15 +42,15 @@ namespace nd4j {
|
||||||
return _INSTANCE;
|
return _INSTANCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int dimension, const bool keepUnitiesInShape) {
|
||||||
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
return tadForDimensions(originalShape, &dimension, 1, keepUnitiesInShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, const std::vector<int> &dimensions, const bool keepUnitiesInShape) {
|
||||||
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
return tadForDimensions(originalShape, const_cast<int *>(dimensions.data()), dimensions.size(), keepUnitiesInShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
TadPack& ConstantTadHelper::tadForDimensions(Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
TadPack& ConstantTadHelper::tadForDimensions(const Nd4jLong *originalShape, int* dimensions, int dimLength, const bool keepUnitiesInShape) {
|
||||||
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
TadDescriptor tadDescriptor(originalShape, dimensions, dimLength, keepUnitiesInShape);
|
||||||
return tadForDimensions(tadDescriptor);
|
return tadForDimensions(tadDescriptor);
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
|
||||||
const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
|
const std::vector<NDArray*>& inArrsFF = argsHolderFF.getInArrs();
|
||||||
const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();
|
const std::vector<NDArray*>& inArrsBP = argsHolderBP.getInArrs();
|
||||||
|
|
||||||
// fill input gradient arrays in accordance to type of loss function
|
// fill input gradient arrays in accordance to kind of loss function
|
||||||
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
|
fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
|
||||||
|
|
||||||
// beck prop pass
|
// beck prop pass
|
||||||
|
|
|
@ -987,9 +987,10 @@ namespace shape {
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
ND4J_EXPORT _CUDA_HD int outerArrayIndexes(Nd4jLong* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr);
|
ND4J_EXPORT _CUDA_HD int outerArrayIndexes(Nd4jLong* maxIdxs, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr);
|
||||||
|
|
||||||
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
// calculate offsets of max-array, these offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||||
|
// maxOffsets - will contain calculated offsets of max-array, buffer for maxOffsets should be allocated beforehand
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
|
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be allocated beforehand
|
||||||
ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude = nullptr);
|
ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude = nullptr);
|
||||||
|
|
||||||
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
||||||
|
|
|
@ -15,7 +15,8 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
|
@ -28,46 +29,35 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
DECLARE_TYPES(adjust_hue) {
|
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
|
||||||
->setSameMode(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, -2, -2) {
|
CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 3 || input->rankOf() == 4, 0, "AdjustHue: op expects either 3D or 4D input, but got %i instead", input->rankOf());
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const int rank = input->rankOf();
|
||||||
|
const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||||
|
const double delta = T_ARG(0);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank);
|
||||||
|
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
|
||||||
|
REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);
|
||||||
|
|
||||||
|
NDArray deltaScalarArr = NDArrayFactory::create<double>(delta, block.launchContext());
|
||||||
|
|
||||||
|
helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(adjust_hue) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
double delta = 0;
|
|
||||||
if (block.numT() > 0)
|
|
||||||
delta = T_ARG(0);
|
|
||||||
else if (block.width() > 1) {
|
|
||||||
auto _d = INPUT_VARIABLE(1);
|
|
||||||
if (!_d->isScalar()) {
|
|
||||||
auto str = ShapeUtils::shapeAsString(_d);
|
|
||||||
REQUIRE_TRUE(_d->isScalar(), 0, "AdjustHue: delta should be scalar NDArray, but got %s instead", str.c_str());
|
|
||||||
}
|
|
||||||
delta = _d->e<double>(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
bool isNHWC = false;
|
|
||||||
if (block.numI() > 0)
|
|
||||||
isNHWC = INT_ARG(0) == 1;
|
|
||||||
|
|
||||||
int numChannels = isNHWC ? input->sizeAt(-1) : input->sizeAt(-3);
|
|
||||||
|
|
||||||
REQUIRE_TRUE(numChannels == 3, 0, "AdjustHue: this operation expects image with 3 channels (R, G, B), but got % instead", numChannels);
|
|
||||||
|
|
||||||
auto ts = NDArrayFactory::create(delta, block.launchContext());
|
|
||||||
// FIXME: delta should be NDArray scalar
|
|
||||||
helpers::_adjust_hue(block.launchContext(), input, output, &ts, isNHWC);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,45 +27,33 @@
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
DECLARE_TYPES(adjust_saturation) {
|
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
|
||||||
->setSameMode(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, -2, -2) {
|
CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 3 || input->rankOf() == 4, 0, "AdjustSaturation: op expects either 3D or 4D input, but got %i instead", input->rankOf());
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
double delta = 0;
|
const int rank = input->rankOf();
|
||||||
if (block.numT() > 0)
|
const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||||
delta = T_ARG(0);
|
const double factor = T_ARG(0);
|
||||||
else if (block.width() > 1) {
|
|
||||||
auto _d = INPUT_VARIABLE(1);
|
|
||||||
if (!_d->isScalar()) {
|
|
||||||
auto str = ShapeUtils::shapeAsString(_d);
|
|
||||||
REQUIRE_TRUE(_d->isScalar(), 0, "AdjustSaturation: delta should be scalar NDArray, but got %s instead", str.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
delta = _d->e<double>(0);
|
REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank);
|
||||||
}
|
REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
|
||||||
|
|
||||||
bool isNHWC = false;
|
NDArray factorScalarArr = NDArrayFactory::create<double>(factor, block.launchContext());
|
||||||
if (block.numI() > 0)
|
|
||||||
isNHWC = INT_ARG(0) == 1;
|
|
||||||
|
|
||||||
int numChannels = isNHWC ? input->sizeAt(-1) : input->sizeAt(-3);
|
helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(adjust_saturation) {
|
||||||
|
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
|
||||||
REQUIRE_TRUE(numChannels == 3, 0, "AdjustSaturation: this operation expects image with 3 channels (R, G, B), but got % instead", numChannels);
|
|
||||||
|
|
||||||
auto ts = NDArrayFactory::create(delta, block.launchContext());
|
|
||||||
// FIXME: delta should be NDArray scalar
|
|
||||||
helpers::adjust_saturation(block.launchContext(), input, output, &ts, isNHWC);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,64 +26,67 @@
|
||||||
#include <ops/declarable/generic/helpers/ScatterHelper.h>
|
#include <ops/declarable/generic/helpers/ScatterHelper.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
OP_IMPL(scatter_add, 3, 1, true) {
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
|
||||||
auto indices = INPUT_VARIABLE(1);
|
|
||||||
auto updates = INPUT_VARIABLE(2);
|
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
OP_IMPL(scatter_add, 3, 1, true) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto indices = INPUT_VARIABLE(1);
|
||||||
|
auto updates = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
const int inRank = input->rankOf();
|
const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
|
||||||
const int indRank = indices->rankOf();
|
|
||||||
const int updRank = updates->rankOf();
|
|
||||||
const Nd4jLong indLen = indices->lengthOf();
|
|
||||||
|
|
||||||
REQUIRE_TRUE(inRank > 0, 0, "SCATTER_ADD OP: input should not be scalar !");
|
const int inRank = input->rankOf();
|
||||||
|
const int indRank = indices->rankOf();
|
||||||
|
const int updRank = updates->rankOf();
|
||||||
|
const Nd4jLong indLen = indices->lengthOf();
|
||||||
|
|
||||||
if(inRank == 1) {
|
REQUIRE_TRUE(inRank > 0, 0, "SCATTER_ADD OP: input should not be scalar !");
|
||||||
REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_ADD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
|
|
||||||
}
|
|
||||||
else if (inRank == updRank && indices->isVector()) {
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
|
if(inRank == 1) {
|
||||||
std::vector<Nd4jLong> inShape = input->getShapeAsVector();
|
REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_ADD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
|
||||||
std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
|
}
|
||||||
expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
|
else if (inRank == updRank && indices->isVector()) {
|
||||||
|
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
|
||||||
}
|
std::vector<Nd4jLong> inShape = input->getShapeAsVector();
|
||||||
else {
|
std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
|
||||||
|
expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
|
||||||
|
|
||||||
REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_ADD OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
|
REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_ADD OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
|
||||||
std::vector<Nd4jLong> inShape = input->getShapeAsVector();
|
|
||||||
std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
|
|
||||||
expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + Nd4jLong(1L), inShape.end());
|
|
||||||
|
|
||||||
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
|
||||||
}
|
std::vector<Nd4jLong> inShape = input->getShapeAsVector();
|
||||||
|
std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
|
||||||
|
expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + Nd4jLong(1L), inShape.end());
|
||||||
|
|
||||||
if (!block.isInplace())
|
REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
|
||||||
output->assign(input);
|
|
||||||
|
|
||||||
helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
DECLARE_SYN(ScatterAdd, scatter_add);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(scatter_add) {
|
if (!block.isInplace())
|
||||||
getOpDescriptor()
|
output->assign(input);
|
||||||
->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
|
|
||||||
->setAllowedInputTypes(1, {ALL_INTS})
|
helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
|
||||||
->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
|
|
||||||
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DECLARE_SYN(ScatterAdd, scatter_add);
|
||||||
|
|
||||||
|
DECLARE_TYPES(scatter_add) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(1, {ALL_INTS})
|
||||||
|
->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS});
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -57,16 +57,26 @@ namespace nd4j {
|
||||||
auto in = inputShape->at(0);
|
auto in = inputShape->at(0);
|
||||||
int outRank = shape::rank(in) + 1;
|
int outRank = shape::rank(in) + 1;
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto dtype = DataType::BOOL;
|
||||||
Nd4jLong maxInd = input->argMax();
|
Nd4jLong maxInd = input->argMax();
|
||||||
float max = input->e<float>(maxInd);
|
Nd4jLong max = input->e<Nd4jLong>(maxInd);
|
||||||
|
|
||||||
if (block.getIArguments()->size() > 0) {
|
if (block.getIArguments()->size() > 0) {
|
||||||
|
if (block.width() < 2) {
|
||||||
maxInd = INT_ARG(0);
|
maxInd = INT_ARG(0);
|
||||||
if (maxInd < max)
|
if (maxInd < max)
|
||||||
maxInd = static_cast<Nd4jLong>(max);
|
maxInd = static_cast<Nd4jLong>(max);
|
||||||
|
if (block.getIArguments()->size() > 1)
|
||||||
|
dtype = (DataType)INT_ARG(1);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
dtype = (DataType)INT_ARG(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else if (block.width() > 1) {
|
|
||||||
|
if (block.width() > 1) {
|
||||||
auto maxlen = INPUT_VARIABLE(1);
|
auto maxlen = INPUT_VARIABLE(1);
|
||||||
float tmaxlen = maxlen->e<float>(0);
|
Nd4jLong tmaxlen = maxlen->e<Nd4jLong>(0);
|
||||||
if (tmaxlen > max)
|
if (tmaxlen > max)
|
||||||
maxInd = static_cast<Nd4jLong>(tmaxlen);
|
maxInd = static_cast<Nd4jLong>(tmaxlen);
|
||||||
}
|
}
|
||||||
|
@ -80,14 +90,14 @@ namespace nd4j {
|
||||||
outShapeInfo[i + 1] = shape::sizeAt(in, i);
|
outShapeInfo[i + 1] = shape::sizeAt(in, i);
|
||||||
outShapeInfo[outRank] = lastDimension;
|
outShapeInfo[outRank] = lastDimension;
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in));
|
ShapeUtils::updateStridesAndType(outShapeInfo, dtype, shape::order(in));
|
||||||
|
|
||||||
return SHAPELIST(CONSTANT(outShapeInfo));
|
return SHAPELIST(CONSTANT(outShapeInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(sequence_mask) {
|
DECLARE_TYPES(sequence_mask) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes({ALL_INTS})
|
||||||
->setAllowedOutputTypes(nd4j::DataType::ANY);
|
->setAllowedOutputTypes(nd4j::DataType::ANY);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,11 +33,11 @@ OP_IMPL(random_shuffle, 1, 1, true) {
|
||||||
const bool isInplace = block.isInplace();
|
const bool isInplace = block.isInplace();
|
||||||
auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0);
|
auto output = isInplace ? nullptr : OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
nd4j::random::RandomBuffer* rng = block.getRNG();
|
// nd4j::random::RandomBuffer* rng = block.getRNG();
|
||||||
|
nd4j::graph::RandomGenerator rng = block.randomGenerator();
|
||||||
|
// REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !");
|
||||||
|
|
||||||
REQUIRE_TRUE(rng != nullptr, 0, "RANDOM_SHUFFLE op: RNG should be defined in Graph !");
|
helpers::randomShuffle(block.launchContext(), *input, *output, rng, isInplace);
|
||||||
|
|
||||||
helpers::randomShuffle(block.launchContext(), *input, *output, *rng, isInplace);
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,17 +31,18 @@ namespace ops {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) {
|
CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size
|
|
||||||
auto hLast = INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
|
|
||||||
auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights)
|
|
||||||
auto Wc = INPUT_VARIABLE(3); // C weights - [nIn+nU, nU] - cell gate (input/recurrent weights)
|
|
||||||
auto bru = INPUT_VARIABLE(4); // reset and update biases, [2*nU] - reset and update gates
|
|
||||||
auto bc = INPUT_VARIABLE(5); // cell biases, [nU]
|
|
||||||
|
|
||||||
auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, nU]
|
auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size
|
||||||
auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, nU]
|
auto hLast = INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
|
||||||
auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, nU]
|
auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights)
|
||||||
auto h = OUTPUT_VARIABLE(3); // current cell output [bS, nU]
|
auto Wc = INPUT_VARIABLE(3); // C weights - [nIn+nU, nU] - cell gate (input/recurrent weights)
|
||||||
|
auto bru = INPUT_VARIABLE(4); // reset and update biases, [2*nU] - reset and update gates
|
||||||
|
auto bc = INPUT_VARIABLE(5); // cell biases, [nU]
|
||||||
|
|
||||||
|
auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, nU]
|
||||||
|
auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, nU]
|
||||||
|
auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, nU]
|
||||||
|
auto h = OUTPUT_VARIABLE(3); // current cell output [bS, nU]
|
||||||
|
|
||||||
REQUIRE_TRUE(x->rankOf()==2 && hLast->rankOf()==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", x->rankOf(), hLast->rankOf());
|
REQUIRE_TRUE(x->rankOf()==2 && hLast->rankOf()==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", x->rankOf(), hLast->rankOf());
|
||||||
|
|
||||||
|
@ -118,65 +119,58 @@ DECLARE_SHAPE_FN(gruCell) {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(gruCell_bp, 6, 5, false, 0, 0) {
|
CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0); // input [bS x iS]
|
auto x = INPUT_VARIABLE(0); // input [bS x iS]
|
||||||
auto hi = INPUT_VARIABLE(1); // previous cell output [bS x nU]
|
auto hi = INPUT_VARIABLE(1); // previous cell output [bS x nU]
|
||||||
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [iS x 3*nU]
|
auto W = INPUT_VARIABLE(2); // weights, [iS+nU x 2*nU]
|
||||||
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nU x 3*nU]
|
auto Wc = INPUT_VARIABLE(3); // c weights, [iS+nU x nU]
|
||||||
auto b = INPUT_VARIABLE(4); // biases, [3*nU]
|
auto b = INPUT_VARIABLE(4); // biases, [2*nU]
|
||||||
auto dLdh = INPUT_VARIABLE(5); // gradient wrt output, [bS,nU], that is epsilon_next
|
auto bc = INPUT_VARIABLE(5); // biases, [nU]
|
||||||
auto dLdWxi = block.width() > 6 ? INPUT_VARIABLE(6) : nullptr; // gradient wrt Wx at previous time step, [iS, 3*nU]
|
auto dLdr = INPUT_VARIABLE(6); // gradient wrt reset gate, [bS, nU]
|
||||||
auto dLdWhi = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // gradient wrt Wh at previous time step, [nU, 3*nU]
|
auto dLdu = INPUT_VARIABLE(7); // gradient wrt update gate, [bS, nU]
|
||||||
auto dLdbi = block.width() > 8 ? INPUT_VARIABLE(8) : nullptr; // gradient wrt b at previous time step, [3*nU]
|
auto dLdc = INPUT_VARIABLE(8); // gradient wrt cell state, [bS, nU]
|
||||||
|
auto dLdh = INPUT_VARIABLE(9); // gradient wrt current cell output, [bS, nU]
|
||||||
|
|
||||||
auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS], that is epsilon
|
auto dLdx = OUTPUT_VARIABLE(0); // gradient wrt x, [bS, iS]
|
||||||
auto dLdhi = OUTPUT_VARIABLE(1); // gradient wrt hi, [bS, nU]
|
auto dLdhi = OUTPUT_VARIABLE(1); // gradient wrt hi, [bS, nU]
|
||||||
auto dLdWx = OUTPUT_VARIABLE(2); // gradient wrt Wx, [iS, 3*nU]
|
auto dLdW = OUTPUT_VARIABLE(2); // gradient wrt W, [iS+nU x 2*nU]
|
||||||
auto dLdWh = OUTPUT_VARIABLE(3); // gradient wrt Wh, [nU, 3*nU]
|
auto dLdWc = OUTPUT_VARIABLE(3); // gradient wrt Wc, [iS+nU x nU]
|
||||||
auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [3*nU]
|
auto dLdb = OUTPUT_VARIABLE(4); // gradient wrt biases, [2*nU]
|
||||||
|
auto dLdbc = OUTPUT_VARIABLE(5); // gradient wrt c biases, [nU]
|
||||||
|
|
||||||
const int rank = x->rankOf(); // = 2
|
const Nd4jLong bS = x->sizeAt(0);
|
||||||
const Nd4jLong bS = x->sizeAt(0);
|
const Nd4jLong iS = x->sizeAt(1);
|
||||||
const Nd4jLong iS = x->sizeAt(1);
|
const Nd4jLong nU = hi->sizeAt(1);
|
||||||
const Nd4jLong nU = hi->sizeAt(1);
|
|
||||||
|
|
||||||
const std::string hiShape = ShapeUtils::shapeAsString(hi);
|
REQUIRE_TRUE(x->rankOf() == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", x->rankOf());
|
||||||
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
|
||||||
const std::string wxShape = ShapeUtils::shapeAsString(Wx);
|
|
||||||
const std::string wxCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
|
|
||||||
const std::string whShape = ShapeUtils::shapeAsString(Wh);
|
|
||||||
const std::string whCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
|
|
||||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
|
||||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*nU});
|
|
||||||
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdh);
|
|
||||||
const std::string dLdhCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
|
||||||
|
|
||||||
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
|
const std::string hiShape = ShapeUtils::shapeAsString(hi);
|
||||||
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU_CELL_BP op: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
|
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
||||||
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU_CELL_BP op: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
|
const std::string wShape = ShapeUtils::shapeAsString(W);
|
||||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
|
||||||
REQUIRE_TRUE(dLdhShape == dLdhCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (epsilon_next), expected is %s, but got %s instead !", dLdhCorrectShape.c_str(), dLdhShape.c_str());
|
const std::string wcShape = ShapeUtils::shapeAsString(Wc);
|
||||||
|
const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
|
||||||
|
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||||
|
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
|
||||||
|
const std::string bcShape = ShapeUtils::shapeAsString(bc);
|
||||||
|
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
|
||||||
|
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdr);
|
||||||
|
const std::string dLduShape = ShapeUtils::shapeAsString(dLdu);
|
||||||
|
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdc);
|
||||||
|
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdh);
|
||||||
|
|
||||||
if(dLdWxi != nullptr) {
|
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
|
||||||
const std::string dLdWxiShape = ShapeUtils::shapeAsString(dLdWxi);
|
REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||||
const std::string dLdWxiCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
|
REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
|
||||||
REQUIRE_TRUE(dLdWxiShape == dLdWxiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWxi array (gradient wrt Wx at previous time step), expected is %s, but got %s instead !", dLdWxiCorrectShape.c_str(), dLdWxiShape.c_str());
|
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||||
}
|
REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
|
||||||
|
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
|
||||||
|
REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
|
||||||
|
REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
|
||||||
|
REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
|
||||||
|
|
||||||
if(dLdWhi != nullptr) {
|
helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);
|
||||||
const std::string dLdWhiShape = ShapeUtils::shapeAsString(dLdWhi);
|
|
||||||
const std::string dLdWhiCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
|
|
||||||
REQUIRE_TRUE(dLdWhiShape == dLdWhiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWhi array (gradient wrt Wh at previous time step), expected is %s, but got %s instead !", dLdWhiCorrectShape.c_str(), dLdWhiShape.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if(dLdbi != nullptr) {
|
|
||||||
const std::string dLdbiShape = ShapeUtils::shapeAsString(dLdbi);
|
|
||||||
const std::string dLdbiCorrectShape = ShapeUtils::shapeAsString({3*nU});
|
|
||||||
REQUIRE_TRUE(dLdbiShape == dLdbiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdbi array (gradient wrt biases at previous time step), expected is %s, but got %s instead !", dLdbiCorrectShape.c_str(), dLdbiShape.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
helpers::gruCellBP(block.launchContext(), x, hi, Wx, Wh, b, dLdh, dLdWxi, dLdWhi, dLdbi, dLdx, dLdhi, dLdWx, dLdWh, dLdb);
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -192,60 +186,54 @@ DECLARE_TYPES(gruCell_bp) {
|
||||||
->setAllowedInputTypes(6, {ALL_FLOATS})
|
->setAllowedInputTypes(6, {ALL_FLOATS})
|
||||||
->setAllowedInputTypes(7, {ALL_FLOATS})
|
->setAllowedInputTypes(7, {ALL_FLOATS})
|
||||||
->setAllowedInputTypes(8, {ALL_FLOATS})
|
->setAllowedInputTypes(8, {ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(9, {ALL_FLOATS})
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(gruCell_bp) {
|
DECLARE_SHAPE_FN(gruCell_bp) {
|
||||||
|
|
||||||
auto xShapeInfo = inputShape->at(0); // [bS x iS]
|
auto xShapeInfo = inputShape->at(0); // [bS x iS]
|
||||||
auto hiShapeInfo = inputShape->at(1); // [bS x nU]
|
auto hiShapeInfo = inputShape->at(1); // [bS x nU]
|
||||||
auto wxShapeInfo = inputShape->at(2); // [iS x 3*nU]
|
auto wShapeInfo = inputShape->at(2); // [iS+nU x 2*nU]
|
||||||
auto whShapeInfo = inputShape->at(3); // [nU x 3*nU]
|
auto wcShapeInfo = inputShape->at(3); // [iS+nU x nU]
|
||||||
auto bShapeInfo = inputShape->at(4); // [3*nU]
|
auto bShapeInfo = inputShape->at(4); // [2*nU]
|
||||||
auto dLdhShapeInfo = inputShape->at(5); // [bS x nU]
|
auto bcShapeInfo = inputShape->at(5); // [nU]
|
||||||
|
auto dLdrShapeInfo = inputShape->at(6); // [bS, nU]
|
||||||
|
auto dLduShapeInfo = inputShape->at(7); // [bS, nU]
|
||||||
|
auto dLdcShapeInfo = inputShape->at(8); // [bS, nU]
|
||||||
|
auto dLdhShapeInfo = inputShape->at(9); // [bS, nU]
|
||||||
|
|
||||||
const int rank = xShapeInfo[0]; // = 2
|
const int rank = xShapeInfo[0]; // = 2
|
||||||
const Nd4jLong bS = xShapeInfo[1];
|
const Nd4jLong bS = xShapeInfo[1];
|
||||||
const Nd4jLong iS = xShapeInfo[2];
|
const Nd4jLong iS = xShapeInfo[2];
|
||||||
const Nd4jLong nU = hiShapeInfo[2];
|
const Nd4jLong nU = hiShapeInfo[2];
|
||||||
|
|
||||||
const std::string hiShape = ShapeUtils::shapeAsString(hiShapeInfo);
|
REQUIRE_TRUE(xShapeInfo[0] == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", xShapeInfo[0]);
|
||||||
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
|
||||||
const std::string wxShape = ShapeUtils::shapeAsString(wxShapeInfo);
|
|
||||||
const std::string wxCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
|
|
||||||
const std::string whShape = ShapeUtils::shapeAsString(whShapeInfo);
|
|
||||||
const std::string whCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
|
|
||||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
|
||||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*nU});
|
|
||||||
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdhShapeInfo);
|
|
||||||
const std::string dLdhCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
|
||||||
|
|
||||||
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
|
const std::string hiShape = ShapeUtils::shapeAsString(hiShapeInfo);
|
||||||
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU_CELL_BP op: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
|
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
||||||
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU_CELL_BP op: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
|
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
||||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
|
||||||
REQUIRE_TRUE(dLdhShape == dLdhCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (epsilon_next), expected is %s, but got %s instead !", dLdhCorrectShape.c_str(), dLdhShape.c_str());
|
const std::string wcShape = ShapeUtils::shapeAsString(wcShapeInfo);
|
||||||
|
const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
|
||||||
|
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||||
|
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
|
||||||
|
const std::string bcShape = ShapeUtils::shapeAsString(bcShapeInfo);
|
||||||
|
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
|
||||||
|
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdrShapeInfo);
|
||||||
|
const std::string dLduShape = ShapeUtils::shapeAsString(dLduShapeInfo);
|
||||||
|
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdcShapeInfo);
|
||||||
|
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdhShapeInfo);
|
||||||
|
|
||||||
if(block.width() > 6) {
|
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
|
||||||
Nd4jLong* dLdWxiShapeInfo = inputShape->at(6); // [iS x 3*nU]
|
REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||||
const std::string dLdWxiShape = ShapeUtils::shapeAsString(dLdWxiShapeInfo);
|
REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
|
||||||
const std::string dLdWxiCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
|
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||||
REQUIRE_TRUE(dLdWxiShape == dLdWxiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWxi array (gradient wrt Wx at previous time step), expected is %s, but got %s instead !", dLdWxiCorrectShape.c_str(), dLdWxiShape.c_str());
|
REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
|
||||||
}
|
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
|
||||||
|
REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
|
||||||
if(block.width() > 7) {
|
REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
|
||||||
Nd4jLong* dLdWhiShapeInfo = inputShape->at(7); // [nU x 3*nU]
|
REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
|
||||||
const std::string dLdWhiShape = ShapeUtils::shapeAsString(dLdWhiShapeInfo);
|
|
||||||
const std::string dLdWhiCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
|
|
||||||
REQUIRE_TRUE(dLdWhiShape == dLdWhiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdWhi array (gradient wrt Wh at previous time step), expected is %s, but got %s instead !", dLdWhiCorrectShape.c_str(), dLdWhiShape.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if(block.width() > 8) {
|
|
||||||
Nd4jLong* dLdbiShapeInfo = inputShape->at(8); // [3*nU]
|
|
||||||
const std::string dLdbiShape = ShapeUtils::shapeAsString(dLdbiShapeInfo);
|
|
||||||
const std::string dLdbiCorrectShape = ShapeUtils::shapeAsString({3*nU});
|
|
||||||
REQUIRE_TRUE(dLdbiShape == dLdbiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdbi array (gradient wrt biases at previous time step), expected is %s, but got %s instead !", dLdbiCorrectShape.c_str(), dLdbiShape.c_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
Nd4jLong *dLdxShapeInfo = nullptr;
|
Nd4jLong *dLdxShapeInfo = nullptr;
|
||||||
COPY_SHAPE(xShapeInfo, dLdxShapeInfo);
|
COPY_SHAPE(xShapeInfo, dLdxShapeInfo);
|
||||||
|
@ -253,17 +241,19 @@ DECLARE_SHAPE_FN(gruCell_bp) {
|
||||||
Nd4jLong *dLdhiShapeInfo = nullptr;
|
Nd4jLong *dLdhiShapeInfo = nullptr;
|
||||||
COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo);
|
COPY_SHAPE(hiShapeInfo, dLdhiShapeInfo);
|
||||||
|
|
||||||
Nd4jLong *dLdWxShapeInfo = nullptr;
|
Nd4jLong *dLdWShapeInfo = nullptr;
|
||||||
COPY_SHAPE(wxShapeInfo, dLdWxShapeInfo);
|
COPY_SHAPE(wShapeInfo, dLdWShapeInfo);
|
||||||
|
|
||||||
Nd4jLong *dLdWhShapeInfo = nullptr;
|
Nd4jLong *dLdWcShapeInfo = nullptr;
|
||||||
COPY_SHAPE(whShapeInfo, dLdWhShapeInfo);
|
COPY_SHAPE(wcShapeInfo, dLdWcShapeInfo);
|
||||||
|
|
||||||
Nd4jLong *dLdbShapeInfo = nullptr;
|
Nd4jLong *dLdbShapeInfo = nullptr;
|
||||||
COPY_SHAPE(bShapeInfo, dLdbShapeInfo);
|
COPY_SHAPE(bShapeInfo, dLdbShapeInfo);
|
||||||
|
|
||||||
return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo);
|
Nd4jLong *dLdbcShapeInfo = nullptr;
|
||||||
|
COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo);
|
||||||
|
|
||||||
|
return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -71,11 +71,11 @@ CUSTOM_OP_IMPL(static_rnn, 4, 2, false, 0, 0) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(static_rnn) {
|
DECLARE_TYPES(static_rnn) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(static_rnn) {
|
DECLARE_SHAPE_FN(static_rnn) {
|
||||||
|
|
|
@ -553,33 +553,31 @@ namespace nd4j {
|
||||||
/**
|
/**
|
||||||
* This operation adjusts image hue by delta
|
* This operation adjusts image hue by delta
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0 - 1D or 3D input array, must have 3 channels.
|
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
||||||
* 1 - optional scalar, delta value
|
|
||||||
*
|
*
|
||||||
* T arguments:
|
* T arguments:
|
||||||
* 0 - optional delta value
|
* 0 - delta value
|
||||||
*
|
*
|
||||||
* Int arguments:
|
* Int arguments:
|
||||||
* 0 - optional argument, isNHWC. false by default.
|
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_adjust_hue)
|
#if NOT_EXCLUDED(OP_adjust_hue)
|
||||||
DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, -2, -2);
|
DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation adjusts image saturation by delta
|
* This operation adjusts image saturation by delta
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0 - 1D or 3D input array, must have 3 channels.
|
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
||||||
* 1 - optional scalar, delta value
|
|
||||||
*
|
*
|
||||||
* T arguments:
|
* T arguments:
|
||||||
* 0 - optional delta value
|
* 0 - saturation factor
|
||||||
*
|
*
|
||||||
* Int arguments:
|
* Int arguments:
|
||||||
* 0 - optional argument, isNHWC. false by default.
|
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_adjust_saturation)
|
#if NOT_EXCLUDED(OP_adjust_saturation)
|
||||||
DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, -2, -2);
|
DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -259,8 +259,8 @@ namespace ops {
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features
|
* 0: input with shape [batchSize x inSize], batchSize - batch size, inSize - number of features
|
||||||
* 1: previous cell output [batchSize x numUnits], that is at previous time step t-1
|
* 1: previous cell output [batchSize x numUnits], that is at previous time step t-1
|
||||||
* 2: RU weights - [(nIn+nOut), 2*numUnits] - reset and update gates (input/recurrent weights)
|
* 2: RU weights - [(inSize+numUnits), 2*numUnits] - reset and update gates (input/recurrent weights)
|
||||||
* 3: C weights - [(nIn+nOut), numUnits] - cell gate (input/recurrent weights)
|
* 3: C weights - [(inSize+numUnits), numUnits] - cell gate (input/recurrent weights)
|
||||||
* 4: reset and update biases, [2*numUnits] - reset and update gates
|
* 4: reset and update biases, [2*numUnits] - reset and update gates
|
||||||
* 5: cell biases, [numUnits]
|
* 5: cell biases, [numUnits]
|
||||||
*
|
*
|
||||||
|
@ -275,7 +275,7 @@ namespace ops {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_gruCell)
|
#if NOT_EXCLUDED(OP_gruCell)
|
||||||
DECLARE_CUSTOM_OP(gruCell_bp, 6, 5, false, 0, 0);
|
DECLARE_CUSTOM_OP(gruCell_bp, 10, 6, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -15,120 +15,204 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <NDArray.h>
|
#include <NDArray.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
template <typename T>
|
|
||||||
static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) {
|
|
||||||
T v_mid;
|
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC);
|
||||||
int h_category;
|
|
||||||
// According to the figures in:
|
|
||||||
// https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
|
|
||||||
// For the conditions, we don't care about the case where two components are
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
// equal. It is okay to count it in either side in that case.
|
template <typename T>
|
||||||
if (r < g) {
|
FORCEINLINE _CUDA_HD void rgbToHsv(const T& r, const T& g, const T& b, T& h, T& s, T& v) {
|
||||||
if (b < r) {
|
|
||||||
// b < r < g
|
// h values are in range [0, 360)
|
||||||
*v_max = g;
|
// s and v values are in range [0, 1]
|
||||||
v_mid = r;
|
|
||||||
*v_min = b;
|
const T max = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b));
|
||||||
h_category = 1;
|
const T min = nd4j::math::nd4j_min<T>(r, nd4j::math::nd4j_min<T>(g, b));
|
||||||
} else if (b > g) {
|
const T c = max - min;
|
||||||
// r < g < b
|
|
||||||
*v_max = b;
|
// calculate h
|
||||||
v_mid = g;
|
if(c == 0) {
|
||||||
*v_min = r;
|
h = 0;
|
||||||
h_category = 3;
|
}
|
||||||
} else {
|
else if(max == r) {
|
||||||
// r < b < g
|
h = 60.f * ((g - b) / c) + (g >= b ? 0 : 360);
|
||||||
*v_max = g;
|
}
|
||||||
v_mid = b;
|
else if(max == g) {
|
||||||
*v_min = r;
|
h = 60.f * ((b - r) / c) + 120;
|
||||||
h_category = 2;
|
}
|
||||||
}
|
else { // max == b
|
||||||
|
h = 60.f * ((r - g) / c) + 240;
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate s
|
||||||
|
s = max == (T)0 ? (T)0 : c / max;
|
||||||
|
|
||||||
|
// calculate v
|
||||||
|
v = max / 255.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
FORCEINLINE _CUDA_HD void hsvToRgb(const T& h, const T& s, const T& v, T& r, T& g, T& b) {
|
||||||
|
|
||||||
|
const float sector = h / 60.f;
|
||||||
|
const T c = v * s;
|
||||||
|
|
||||||
|
if(0.f <= sector && sector < 1.f) {
|
||||||
|
r = v;
|
||||||
|
g = v - c * (1 - sector);
|
||||||
|
b = v - c;
|
||||||
|
}
|
||||||
|
else if(1.f <= sector && sector < 2.f) {
|
||||||
|
r = v - c * (sector - 1);
|
||||||
|
g = v;
|
||||||
|
b = v - c;
|
||||||
|
}
|
||||||
|
else if(2.f <= sector && sector < 3.f) {
|
||||||
|
r = v - c;
|
||||||
|
g = v;
|
||||||
|
b = v - c * (3 - sector);
|
||||||
|
}
|
||||||
|
else if(3.f <= sector && sector < 4.f) {
|
||||||
|
r = v - c;
|
||||||
|
g = v - c * (sector - 3);
|
||||||
|
b = v;
|
||||||
|
}
|
||||||
|
else if(4.f <= sector && sector < 5.f) {
|
||||||
|
r = v - c * (5 - sector);
|
||||||
|
g = v - c;
|
||||||
|
b = v;
|
||||||
|
}
|
||||||
|
else { // 5.f <= sector < 6.f
|
||||||
|
r = v;
|
||||||
|
g = v - c;
|
||||||
|
b = v - c * (sector - 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
r *= 255;
|
||||||
|
g *= 255;
|
||||||
|
b *= 255;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static FORCEINLINE _CUDA_HD void rgb_to_hv(T r, T g, T b, T* h, T* v_min, T* v_max) {
|
||||||
|
T v_mid;
|
||||||
|
int h_category;
|
||||||
|
// According to the figures in:
|
||||||
|
// https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
|
||||||
|
// For the conditions, we don't care about the case where two components are
|
||||||
|
// equal. It is okay to count it in either side in that case.
|
||||||
|
if (r < g) {
|
||||||
|
if (b < r) {
|
||||||
|
// b < r < g
|
||||||
|
*v_max = g;
|
||||||
|
v_mid = r;
|
||||||
|
*v_min = b;
|
||||||
|
h_category = 1;
|
||||||
|
} else if (b > g) {
|
||||||
|
// r < g < b
|
||||||
|
*v_max = b;
|
||||||
|
v_mid = g;
|
||||||
|
*v_min = r;
|
||||||
|
h_category = 3;
|
||||||
} else {
|
} else {
|
||||||
// g < r
|
// r < b < g
|
||||||
if (b < g) {
|
*v_max = g;
|
||||||
// b < g < r
|
v_mid = b;
|
||||||
*v_max = r;
|
*v_min = r;
|
||||||
v_mid = g;
|
h_category = 2;
|
||||||
*v_min = b;
|
|
||||||
h_category = 0;
|
|
||||||
} else if (b > r) {
|
|
||||||
// g < r < b
|
|
||||||
*v_max = b;
|
|
||||||
v_mid = r;
|
|
||||||
*v_min = g;
|
|
||||||
h_category = 4;
|
|
||||||
} else {
|
|
||||||
// g < b < r
|
|
||||||
*v_max = r;
|
|
||||||
v_mid = b;
|
|
||||||
*v_min = g;
|
|
||||||
h_category = 5;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (*v_max == *v_min) {
|
} else {
|
||||||
*h = 0;
|
// g < r
|
||||||
return;
|
if (b < g) {
|
||||||
}
|
// b < g < r
|
||||||
auto ratio = (v_mid - *v_min) / (*v_max - *v_min);
|
*v_max = r;
|
||||||
bool increase = ((h_category & 0x1) == 0);
|
v_mid = g;
|
||||||
*h = h_category + (increase ? ratio : (1 - ratio));
|
*v_min = b;
|
||||||
}
|
h_category = 0;
|
||||||
|
} else if (b > r) {
|
||||||
template <typename T>
|
// g < r < b
|
||||||
static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) {
|
*v_max = b;
|
||||||
int h_category = static_cast<int>(h);
|
v_mid = r;
|
||||||
T ratio = h - (T)h_category;
|
*v_min = g;
|
||||||
bool increase = ((h_category & 0x1) == 0);
|
h_category = 4;
|
||||||
if (!increase)
|
} else {
|
||||||
ratio = 1 - ratio;
|
// g < b < r
|
||||||
|
*v_max = r;
|
||||||
T v_mid = v_min + ratio * (v_max - v_min);
|
v_mid = b;
|
||||||
// According to the figures in:
|
*v_min = g;
|
||||||
// https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
|
h_category = 5;
|
||||||
switch (h_category) {
|
|
||||||
case 0:
|
|
||||||
*r = v_max;
|
|
||||||
*g = v_mid;
|
|
||||||
*b = v_min;
|
|
||||||
break;
|
|
||||||
case 1:
|
|
||||||
*r = v_mid;
|
|
||||||
*g = v_max;
|
|
||||||
*b = v_min;
|
|
||||||
break;
|
|
||||||
case 2:
|
|
||||||
*r = v_min;
|
|
||||||
*g = v_max;
|
|
||||||
*b = v_mid;
|
|
||||||
break;
|
|
||||||
case 3:
|
|
||||||
*r = v_min;
|
|
||||||
*g = v_mid;
|
|
||||||
*b = v_max;
|
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
*r = v_mid;
|
|
||||||
*g = v_min;
|
|
||||||
*b = v_max;
|
|
||||||
break;
|
|
||||||
case 5:
|
|
||||||
default:
|
|
||||||
*r = v_max;
|
|
||||||
*g = v_min;
|
|
||||||
*b = v_mid;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (*v_max == *v_min) {
|
||||||
|
*h = 0;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto ratio = (v_mid - *v_min) / (*v_max - *v_min);
|
||||||
|
bool increase = ((h_category & 0x1) == 0);
|
||||||
|
*h = h_category + (increase ? ratio : (1 - ratio));
|
||||||
|
}
|
||||||
|
|
||||||
void _adjust_hue(nd4j::LaunchContext * context, NDArray *input, NDArray *output, NDArray *delta, bool isNHWC);
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static FORCEINLINE _CUDA_HD void hv_to_rgb(T h, T v_min, T v_max, T* r, T* g, T* b) {
|
||||||
|
int h_category = static_cast<int>(h);
|
||||||
|
T ratio = h - (T)h_category;
|
||||||
|
bool increase = ((h_category & 0x1) == 0);
|
||||||
|
if (!increase)
|
||||||
|
ratio = 1 - ratio;
|
||||||
|
|
||||||
|
T v_mid = v_min + ratio * (v_max - v_min);
|
||||||
|
// According to the figures in:
|
||||||
|
// https://en.wikipedia.org/wiki/HSL_and_HSV#Hue_and_chroma
|
||||||
|
switch (h_category) {
|
||||||
|
case 0:
|
||||||
|
*r = v_max;
|
||||||
|
*g = v_mid;
|
||||||
|
*b = v_min;
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
*r = v_mid;
|
||||||
|
*g = v_max;
|
||||||
|
*b = v_min;
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
*r = v_min;
|
||||||
|
*g = v_max;
|
||||||
|
*b = v_mid;
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
*r = v_min;
|
||||||
|
*g = v_mid;
|
||||||
|
*b = v_max;
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
*r = v_mid;
|
||||||
|
*g = v_min;
|
||||||
|
*b = v_max;
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
default:
|
||||||
|
*r = v_max;
|
||||||
|
*g = v_min;
|
||||||
|
*b = v_mid;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -15,16 +15,21 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
#include <templatemath.h>
|
#include <templatemath.h>
|
||||||
#include <NDArray.h>
|
#include <NDArray.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
|
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC);
|
||||||
|
|
||||||
|
/*
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) {
|
static FORCEINLINE _CUDA_HD void rgb_to_hsv(T r, T g, T b, T* h, T* s, T* v) {
|
||||||
T vv = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b));
|
T vv = nd4j::math::nd4j_max<T>(r, nd4j::math::nd4j_max<T>(g, b));
|
||||||
|
@ -109,8 +114,8 @@ namespace helpers {
|
||||||
*g = gg + m;
|
*g = gg + m;
|
||||||
*b = bb + m;
|
*b = bb + m;
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
void adjust_saturation(nd4j::LaunchContext * context, NDArray *input, NDArray *output, NDArray *delta, bool isNHWC);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -15,107 +15,177 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/adjust_hue.h>
|
#include <ops/declarable/helpers/adjust_hue.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static void _adjust_hue_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
|
||||||
// we're 100% sure it's 3
|
|
||||||
const int numChannels = 3;
|
|
||||||
int tuples = array->lengthOf() / numChannels;
|
|
||||||
auto bIn = reinterpret_cast<T *>(array->buffer());
|
|
||||||
auto bOut = reinterpret_cast<T *>(output->buffer());
|
|
||||||
static const int kChannelRange = 6;
|
|
||||||
|
|
||||||
int stridesDim = isNHWC ? 2 : 0;
|
template <typename T>
|
||||||
if (isNHWC) {
|
static void adjustHue_(const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {
|
||||||
// for NHWC our rgb values are stored one by one
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < tuples; e++) {
|
|
||||||
auto i = bIn + e * numChannels;
|
|
||||||
auto o = bOut + e * numChannels;
|
|
||||||
|
|
||||||
T h, v_min, v_max;
|
const T delta = deltaScalarArr->e<T>(0);
|
||||||
helpers::rgb_to_hv(i[0], i[1], i[2], &h, &v_min, &v_max);
|
const int rank = input->rankOf();
|
||||||
|
|
||||||
h += delta * kChannelRange;
|
const T* x = input->bufferAsT<T>();
|
||||||
while (h < (T) 0.)
|
T* z = output->bufferAsT<T>();
|
||||||
h += (T) kChannelRange;
|
|
||||||
|
|
||||||
while (h >= (T) kChannelRange)
|
if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
|
||||||
h -= (T) kChannelRange;
|
|
||||||
|
|
||||||
helpers::hv_to_rgb(h, v_min, v_max, o, o + 1, o + 2);
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
}
|
for (Nd4jLong i = 0; i < input->lengthOf(); i += 3) {
|
||||||
} else {
|
|
||||||
auto tadsChannelsIn = array->allTensorsAlongDimension({0});
|
|
||||||
auto tadsChannelsOut = output->allTensorsAlongDimension( {0});
|
|
||||||
|
|
||||||
auto bufferR = reinterpret_cast<T *>(tadsChannelsIn->at(0)->buffer());
|
T h, s, v;
|
||||||
auto bufferG = reinterpret_cast<T *>(tadsChannelsIn->at(1)->buffer());
|
|
||||||
auto bufferB = reinterpret_cast<T *>(tadsChannelsIn->at(2)->buffer());
|
|
||||||
|
|
||||||
auto outputR = reinterpret_cast<T *>(tadsChannelsOut->at(0)->buffer());
|
rgbToHsv<T>(x[i], x[i+1], x[i+2], h, s, v);
|
||||||
auto outputG = reinterpret_cast<T *>(tadsChannelsOut->at(1)->buffer());
|
|
||||||
auto outputB = reinterpret_cast<T *>(tadsChannelsOut->at(2)->buffer());
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
h += delta * 360;
|
||||||
for (int e = 0; e < tuples; e++) {
|
if(h > 360)
|
||||||
auto _ri = bufferR + e;
|
h -= 360;
|
||||||
auto _gi = bufferG + e;
|
else if(h < 0)
|
||||||
auto _bi = bufferB + e;
|
h += 360;
|
||||||
|
|
||||||
auto _ro = outputR + e;
|
hsvToRgb<T>(h, s, v, z[i], z[i+1], z[i+2]);
|
||||||
auto _go = outputG + e;
|
|
||||||
auto _bo = outputB + e;
|
|
||||||
|
|
||||||
T h, v_min, v_max;
|
|
||||||
helpers::rgb_to_hv(_ri[0], _gi[0], _bi[0], &h, &v_min, &v_max);
|
|
||||||
|
|
||||||
h += delta * kChannelRange;
|
|
||||||
while (h < (T) 0)
|
|
||||||
h += (T) kChannelRange;
|
|
||||||
|
|
||||||
while (h >= (T) kChannelRange)
|
|
||||||
h -= (T) kChannelRange;
|
|
||||||
|
|
||||||
helpers::hv_to_rgb(h, v_min, v_max, _ro, _go, _bo);
|
|
||||||
}
|
|
||||||
|
|
||||||
delete tadsChannelsIn;
|
|
||||||
delete tadsChannelsOut;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
void _adjust_hue(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
|
||||||
auto xType = array->dataType();
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
|
||||||
|
|
||||||
float d = delta->e<float>(0);
|
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||||
if (array->rankOf() == 4) {
|
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
|
||||||
auto tadsIn = array->allTensorsAlongDimension({0});
|
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
|
||||||
auto tadsOut = output->allTensorsAlongDimension({0});
|
|
||||||
int tSize = tadsIn->size();
|
|
||||||
// FIXME: template selector should be moved out of loop
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
|
||||||
for (int e = 0; e < tSize; e++) {
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for(Nd4jLong i = 0; i < numOfTads; ++i) {
|
||||||
|
|
||||||
|
const T* xTad = x + packX.platformOffsets()[i];
|
||||||
|
T* zTad = z + packZ.platformOffsets()[i];
|
||||||
|
|
||||||
|
T h, s, v;
|
||||||
|
|
||||||
|
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
|
||||||
|
|
||||||
|
h += delta * 360;
|
||||||
|
if(h > 360)
|
||||||
|
h -= 360;
|
||||||
|
else if(h < 0)
|
||||||
|
h += 360;
|
||||||
|
|
||||||
|
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||||
|
|
||||||
delete tadsIn;
|
|
||||||
delete tadsOut;
|
|
||||||
} else {
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), adjustHue_, (input, deltaScalarArr, output, dimC), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
template <typename T>
|
||||||
|
static void adjust_hue_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
||||||
|
// we're 100% sure it's 3
|
||||||
|
const int numChannels = 3;
|
||||||
|
int tuples = array->lengthOf() / numChannels;
|
||||||
|
auto bIn = reinterpret_cast<T *>(array->buffer());
|
||||||
|
auto bOut = reinterpret_cast<T *>(output->buffer());
|
||||||
|
static const int kChannelRange = 6;
|
||||||
|
|
||||||
|
int stridesDim = isNHWC ? 2 : 0;
|
||||||
|
if (isNHWC) {
|
||||||
|
// for NHWC our rgb values are stored one by one
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < tuples; e++) {
|
||||||
|
auto i = bIn + e * numChannels;
|
||||||
|
auto o = bOut + e * numChannels;
|
||||||
|
|
||||||
|
T h, v_min, v_max;
|
||||||
|
helpers::rgb_to_hv(i[0], i[1], i[2], &h, &v_min, &v_max);
|
||||||
|
|
||||||
|
h += delta * kChannelRange;
|
||||||
|
while (h < (T) 0.)
|
||||||
|
h += (T) kChannelRange;
|
||||||
|
|
||||||
|
while (h >= (T) kChannelRange)
|
||||||
|
h -= (T) kChannelRange;
|
||||||
|
|
||||||
|
helpers::hv_to_rgb(h, v_min, v_max, o, o + 1, o + 2);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto tadsChannelsIn = array->allTensorsAlongDimension({0});
|
||||||
|
auto tadsChannelsOut = output->allTensorsAlongDimension( {0});
|
||||||
|
|
||||||
|
auto bufferR = reinterpret_cast<T *>(tadsChannelsIn->at(0)->buffer());
|
||||||
|
auto bufferG = reinterpret_cast<T *>(tadsChannelsIn->at(1)->buffer());
|
||||||
|
auto bufferB = reinterpret_cast<T *>(tadsChannelsIn->at(2)->buffer());
|
||||||
|
|
||||||
|
auto outputR = reinterpret_cast<T *>(tadsChannelsOut->at(0)->buffer());
|
||||||
|
auto outputG = reinterpret_cast<T *>(tadsChannelsOut->at(1)->buffer());
|
||||||
|
auto outputB = reinterpret_cast<T *>(tadsChannelsOut->at(2)->buffer());
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < tuples; e++) {
|
||||||
|
auto _ri = bufferR + e;
|
||||||
|
auto _gi = bufferG + e;
|
||||||
|
auto _bi = bufferB + e;
|
||||||
|
|
||||||
|
auto _ro = outputR + e;
|
||||||
|
auto _go = outputG + e;
|
||||||
|
auto _bo = outputB + e;
|
||||||
|
|
||||||
|
T h, v_min, v_max;
|
||||||
|
helpers::rgb_to_hv(_ri[0], _gi[0], _bi[0], &h, &v_min, &v_max);
|
||||||
|
|
||||||
|
h += delta * kChannelRange;
|
||||||
|
while (h < (T) 0)
|
||||||
|
h += (T) kChannelRange;
|
||||||
|
|
||||||
|
while (h >= (T) kChannelRange)
|
||||||
|
h -= (T) kChannelRange;
|
||||||
|
|
||||||
|
helpers::hv_to_rgb(h, v_min, v_max, _ro, _go, _bo);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete tadsChannelsIn;
|
||||||
|
delete tadsChannelsOut;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void adjust_hue_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
||||||
|
auto xType = array->dataType();
|
||||||
|
|
||||||
|
float d = delta->e<float>(0);
|
||||||
|
if (array->rankOf() == 4) {
|
||||||
|
auto tadsIn = array->allTensorsAlongDimension({0});
|
||||||
|
auto tadsOut = output->allTensorsAlongDimension({0});
|
||||||
|
int tSize = tadsIn->size();
|
||||||
|
// FIXME: template selector should be moved out of loop
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
for (int e = 0; e < tSize; e++) {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
delete tadsIn;
|
||||||
|
delete tadsOut;
|
||||||
|
} else {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, adjust_hue_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void adjust_hue_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES);
|
||||||
|
*/
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void _adjust_hue_single, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC);, FLOAT_TYPES);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,99 +15,168 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/adjust_saturation.h>
|
#include <ops/declarable/helpers/adjust_saturation.h>
|
||||||
|
#include <ops/declarable/helpers/adjust_hue.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void adjust_saturation_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
static void adjustSaturation_(const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {
|
||||||
// we're 100% sure it's 3
|
|
||||||
const int numChannels = 3;
|
|
||||||
int tuples = array->lengthOf() / numChannels;
|
|
||||||
auto bIn = reinterpret_cast<T *>(array->buffer());
|
|
||||||
auto bOut = reinterpret_cast<T *>(output->buffer());
|
|
||||||
static const int kChannelRange = 6;
|
|
||||||
|
|
||||||
if (isNHWC) {
|
const T factor = factorScalarArr->e<T>(0);
|
||||||
// for NHWC our rgb values are stored one by one
|
const int rank = input->rankOf();
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < tuples; e++) {
|
|
||||||
auto i = bIn + e * numChannels;
|
|
||||||
auto o = bOut + e * numChannels;
|
|
||||||
|
|
||||||
T h, s, v;
|
const T* x = input->bufferAsT<T>();
|
||||||
// Convert the RGB color to Hue/V-range.
|
T* z = output->bufferAsT<T>();
|
||||||
helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v);
|
|
||||||
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
|
|
||||||
// Convert the hue and v-range back into RGB.
|
|
||||||
helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto tadsChannelsIn = array->allTensorsAlongDimension({0});
|
|
||||||
auto tadsChannelsOut = output->allTensorsAlongDimension({0});
|
|
||||||
|
|
||||||
auto bufferR = reinterpret_cast<T *>(tadsChannelsIn->at(0)->buffer());
|
if(dimC == rank - 1 && input->ews() == 1 && output->ews() == 1 && input->ordering() == 'c' && output->ordering() == 'c') {
|
||||||
auto bufferG = reinterpret_cast<T *>(tadsChannelsIn->at(1)->buffer());
|
|
||||||
auto bufferB = reinterpret_cast<T *>(tadsChannelsIn->at(2)->buffer());
|
|
||||||
|
|
||||||
auto outputR = reinterpret_cast<T *>(tadsChannelsOut->at(0)->buffer());
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
auto outputG = reinterpret_cast<T *>(tadsChannelsOut->at(1)->buffer());
|
for (Nd4jLong i = 0; i < input->lengthOf(); i += 3) {
|
||||||
auto outputB = reinterpret_cast<T *>(tadsChannelsOut->at(2)->buffer());
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
T h, s, v;
|
||||||
for (int e = 0; e < tuples; e++) {
|
|
||||||
auto _ri = bufferR + e;
|
|
||||||
auto _gi = bufferG + e;
|
|
||||||
auto _bi = bufferB + e;
|
|
||||||
|
|
||||||
auto _ro = outputR + e;
|
rgbToHsv<T>(x[i], x[i+1], x[i+2], h, s, v);
|
||||||
auto _go = outputG + e;
|
|
||||||
auto _bo = outputB + e;
|
|
||||||
|
|
||||||
T h, s, v;
|
s *= factor;
|
||||||
// Convert the RGB color to Hue/V-range.
|
if(s > 1.f)
|
||||||
helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v);
|
s = 1.f;
|
||||||
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
|
else if(s < 0.f)
|
||||||
// Convert the hue and v-range back into RGB.
|
s = 0.f;
|
||||||
helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo);
|
|
||||||
}
|
|
||||||
|
|
||||||
delete tadsChannelsIn;
|
hsvToRgb<T>(h, s, v, z[i], z[i+1], z[i+2]);
|
||||||
delete tadsChannelsOut;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
void adjust_saturation(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
|
||||||
auto xType = array->dataType();
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
|
||||||
|
|
||||||
float d = delta->e<float>(0);
|
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||||
if (array->rankOf() == 4) {
|
const Nd4jLong xDimCstride = input->stridesOf()[dimC];
|
||||||
auto tadsIn = array->allTensorsAlongDimension({0});
|
const Nd4jLong zDimCstride = output->stridesOf()[dimC];
|
||||||
auto tadsOut = output->allTensorsAlongDimension({0});
|
|
||||||
int tSize = tadsIn->size();
|
|
||||||
|
|
||||||
// FIXME: template selector should be moved out of loop
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
for(Nd4jLong i = 0; i < numOfTads; ++i) {
|
||||||
for (int e = 0; e < tSize; e++) {
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
const T* xTad = x + packX.platformOffsets()[i];
|
||||||
|
T* zTad = z + packZ.platformOffsets()[i];
|
||||||
|
|
||||||
|
T h, s, v;
|
||||||
|
|
||||||
|
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
|
||||||
|
|
||||||
|
s *= factor;
|
||||||
|
if(s > 1.f)
|
||||||
|
s = 1.f;
|
||||||
|
else if(s < 0.f)
|
||||||
|
s = 0.f;
|
||||||
|
|
||||||
|
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||||
|
|
||||||
delete tadsIn;
|
|
||||||
delete tadsOut;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES);
|
|
||||||
|
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturation_, (input, factorScalarArr, output, dimC), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
template <typename T>
|
||||||
|
static void adjust_saturation_single_(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
||||||
|
// we're 100% sure it's 3
|
||||||
|
const int numChannels = 3;
|
||||||
|
int tuples = array->lengthOf() / numChannels;
|
||||||
|
auto bIn = reinterpret_cast<T *>(array->buffer());
|
||||||
|
auto bOut = reinterpret_cast<T *>(output->buffer());
|
||||||
|
static const int kChannelRange = 6;
|
||||||
|
|
||||||
|
if (isNHWC) {
|
||||||
|
// for NHWC our rgb values are stored one by one
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < tuples; e++) {
|
||||||
|
auto i = bIn + e * numChannels;
|
||||||
|
auto o = bOut + e * numChannels;
|
||||||
|
|
||||||
|
T h, s, v;
|
||||||
|
// Convert the RGB color to Hue/V-range.
|
||||||
|
helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v);
|
||||||
|
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
|
||||||
|
// Convert the hue and v-range back into RGB.
|
||||||
|
helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto tadsChannelsIn = array->allTensorsAlongDimension({0});
|
||||||
|
auto tadsChannelsOut = output->allTensorsAlongDimension({0});
|
||||||
|
|
||||||
|
auto bufferR = reinterpret_cast<T *>(tadsChannelsIn->at(0)->buffer());
|
||||||
|
auto bufferG = reinterpret_cast<T *>(tadsChannelsIn->at(1)->buffer());
|
||||||
|
auto bufferB = reinterpret_cast<T *>(tadsChannelsIn->at(2)->buffer());
|
||||||
|
|
||||||
|
auto outputR = reinterpret_cast<T *>(tadsChannelsOut->at(0)->buffer());
|
||||||
|
auto outputG = reinterpret_cast<T *>(tadsChannelsOut->at(1)->buffer());
|
||||||
|
auto outputB = reinterpret_cast<T *>(tadsChannelsOut->at(2)->buffer());
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < tuples; e++) {
|
||||||
|
auto _ri = bufferR + e;
|
||||||
|
auto _gi = bufferG + e;
|
||||||
|
auto _bi = bufferB + e;
|
||||||
|
|
||||||
|
auto _ro = outputR + e;
|
||||||
|
auto _go = outputG + e;
|
||||||
|
auto _bo = outputB + e;
|
||||||
|
|
||||||
|
T h, s, v;
|
||||||
|
// Convert the RGB color to Hue/V-range.
|
||||||
|
helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v);
|
||||||
|
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
|
||||||
|
// Convert the hue and v-range back into RGB.
|
||||||
|
helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete tadsChannelsIn;
|
||||||
|
delete tadsChannelsOut;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void adjust_saturation(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
||||||
|
auto xType = array->dataType();
|
||||||
|
|
||||||
|
float d = delta->e<float>(0);
|
||||||
|
if (array->rankOf() == 4) {
|
||||||
|
auto tadsIn = array->allTensorsAlongDimension({0});
|
||||||
|
auto tadsOut = output->allTensorsAlongDimension({0});
|
||||||
|
int tSize = tadsIn->size();
|
||||||
|
|
||||||
|
// FIXME: template selector should be moved out of loop
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
for (int e = 0; e < tSize; e++) {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, tadsIn->at(e), tadsOut->at(e), d, isNHWC);, FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
delete tadsIn;
|
||||||
|
delete tadsOut;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, adjust_saturation_single_, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void adjust_saturation_single_, (nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC), FLOAT_TYPES);
|
||||||
|
*/
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,13 +59,16 @@ namespace helpers {
|
||||||
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
|
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
|
||||||
|
|
||||||
bool fit = true;
|
bool fit = true;
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
|
||||||
for( int i = 0; fit && (i < dims.size()); i++ ) {
|
for( int i = 0; i < dims.size(); i++ ) {
|
||||||
dims[i] = reduceShape->e<Nd4jLong>(i);
|
if (fit) {
|
||||||
for (int e = 0; fit && (e < input->rankOf()); ++e)
|
dims[i] = reduceShape->e<Nd4jLong>(i);
|
||||||
if (input->sizeAt(e) % dims[i]) {
|
for (int e = 0; e < input->rankOf(); ++e)
|
||||||
fit = false;
|
if (fit)
|
||||||
}
|
if (input->sizeAt(e) % dims[i]) {
|
||||||
|
fit = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check dims to fit input
|
// check dims to fit input
|
||||||
|
|
|
@ -35,82 +35,88 @@ namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc,
|
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc,
|
||||||
const NDArray* bru, const NDArray* bc,
|
const NDArray* b, const NDArray* bc,
|
||||||
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
|
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
|
||||||
|
|
||||||
//Inputs:
|
//Inputs:
|
||||||
// x input [bS, nIn], nIn - input size
|
// x input [bS, iS], iS - input size
|
||||||
// hLast previous cell output [bS, nUn], that is at previous time step t-1, nUn - number of units
|
// hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
|
||||||
// Wru RU weights - [nIn+nUn, 2*nUn] - reset and update gates
|
// W RU weights - [iS+nU, 2*nU] - reset and update gates
|
||||||
// Wc C weights - [nIn+nUn, nUn] - cell gate
|
// Wc C weights - [iS+nU, nU] - cell gate
|
||||||
// bru r and u biases, [2*nUn] - reset and update gates
|
// b r and u biases, [2*nU] - reset and update gates
|
||||||
// bc c biases, [nUn] - cell gate
|
// bc c biases, [nU] - cell gate
|
||||||
|
|
||||||
//Outputs:
|
//Outputs:
|
||||||
// r Reset gate output [bS, nUn]
|
// r Reset gate output [bS, nU]
|
||||||
// u Update gate output [bS, nUn]
|
// u Update gate output [bS, nU]
|
||||||
// c Cell gate output [bS, nUn]
|
// c Cell gate output [bS, nU]
|
||||||
// h current cell output [bS, nUn]
|
// h current cell output [bS, nU]
|
||||||
|
|
||||||
/***************************************************************************************/
|
/***************************************************************************************/
|
||||||
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
|
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
|
||||||
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
|
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
|
||||||
|
|
||||||
const int bS = x->sizeAt(0);
|
const int bS = x->sizeAt(0);
|
||||||
const int nIn = x->sizeAt(1);
|
const int iS = x->sizeAt(1);
|
||||||
const int nUn = hLast->sizeAt(1);
|
const int nU = hLast->sizeAt(1);
|
||||||
|
|
||||||
NDArray Wr = (*Wru)({0,nIn, 0,0}); // reset gates weights [nIn, 2*nUn]
|
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
|
||||||
NDArray Wu = (*Wru)({nIn,nIn+nUn, 0,0}); // updates gates weights [nUn, 2*nUn]
|
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
|
||||||
|
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||||
|
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||||
|
|
||||||
NDArray Wcr = (*Wc)({0,nIn, 0,0}); // reset cell weights [nIn, nUn]
|
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
|
||||||
NDArray Wcu = (*Wc)({nIn,nIn+nUn, 0,0}); // updates cell weights [nUn, nUn]
|
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
|
||||||
|
|
||||||
// gates = sigmoid(x*Wr + hLast*Wu + br + bu)
|
NDArray br = (*b)({0, nU}); // [nU]
|
||||||
NDArray gates = mmul(*x, Wr) + mmul(*hLast, Wu) + *bru; // [bS, nIn] * [nIn, 2*nUn] + [bS, nUn] * [nUn, 2*nUn] + [2*nUn] = [bS, 2*nUn]
|
NDArray bu = (*b)({nU, 2*nU}); // [nU]
|
||||||
gates.applyTransform(transform::Sigmoid);
|
|
||||||
|
// × means matrix multipication
|
||||||
|
// * means element-wise product or so called Hadamard product
|
||||||
|
|
||||||
// reset gate
|
// reset gate
|
||||||
r->assign(gates({0,0, 0,nUn})); // [bS, nUn]
|
r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
|
r->applyTransform(transform::Sigmoid);
|
||||||
|
|
||||||
// update gate
|
// update gate
|
||||||
u->assign(gates({0,0, nUn,2*nUn})); // [bS, nUn]
|
u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
|
u->applyTransform(transform::Sigmoid);
|
||||||
|
|
||||||
// cell gate c = activation(x*Wcr + (r◦hlast)*Wcu + bc)
|
// cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc)
|
||||||
c->assign(mmul(*x, Wcr) + mmul(*r * *hLast, Wcu) + *bc); // [bS, nIn] * [nIn, nUn] + [bS, nUn] * [nUn, nUn] + [nUn] = [bS, nUn]
|
c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
c->applyTransform(transform::Tanh);
|
c->applyTransform(transform::Tanh);
|
||||||
|
|
||||||
|
NDArray temp = 1.f - *c * *c;
|
||||||
|
|
||||||
// cell output
|
// cell output
|
||||||
h->assign(*u * *hLast + (1.f - *u) * *c);
|
h->assign(*u * *hLast + (1.f - *u) * *c);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/***************************************************************************************/
|
/***************************************************************************************/
|
||||||
/********************** THIS MORE OPTIMAZED CODE (except concat ) **********************/
|
/*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/
|
||||||
/***************************************************************************************/
|
/***************************************************************************************/
|
||||||
/*
|
/*
|
||||||
//Concat inputs: x + hLast : [bs, nIn + nUn]
|
//Concat inputs: x + hLast : [bs, iS + nU]
|
||||||
NDArray xhConcat(x->ordering(), {bS, nIn + nUn}, x->dataType(), context); // concat([bs, nIn], [bs, nUn]) -> [bs, nIn + nUn]
|
NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU]
|
||||||
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
|
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
|
||||||
|
|
||||||
//mmul for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u)
|
//mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u)
|
||||||
auto m = mmul(xhConcat, *Wru) + *bru ; // [bs, nIn+nUn] * [nIn+nUn, 2*nUn] = [bs, 2*nUn]
|
auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU]
|
||||||
// m += *bru;
|
// m += *bru;
|
||||||
|
|
||||||
sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz)
|
m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz)
|
||||||
|
|
||||||
r->assign(m({0,0, 0, nUn}));
|
r->assign(m({0,0, 0, nU}));
|
||||||
u->assign(m({0,0, nUn, 2*nUn}));
|
u->assign(m({0,0, nU, 2*nU}));
|
||||||
|
|
||||||
// hLast = hLast * r
|
// hLast = hLast * r
|
||||||
xhConcat({0,0, nIn, nIn+nUn}) *= *r;
|
xhConcat({0,0, iS, iS+nU}) *= *r;
|
||||||
|
|
||||||
//c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c)
|
//c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c)
|
||||||
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
|
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
|
||||||
*c += *bc;
|
*c += *bc;
|
||||||
tanhInplace(*c);
|
c->applyTransform(transform::Tanh);
|
||||||
|
|
||||||
//Output: h = (1-u).*c + u .* hPrev
|
//Output: h = (1-u).*c + u .* hPrev
|
||||||
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
|
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
|
||||||
|
@ -122,19 +128,19 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
|
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
|
||||||
|
|
||||||
// x input [time, bS, iS]
|
// x input [time, bS, iS]
|
||||||
// h0 initial cell output (at time step = 0) [bS, nUn]
|
// hLast initial cell output (at time step = 0) [bS, nU]
|
||||||
// Wx input-to-hidden weights, [iS, 3*nUn]
|
// Wx input-to-hidden weights, [iS, 3*nU]
|
||||||
// Wh hidden-to-hidden weights, [nUn, 3*nUn]
|
// Wh hidden-to-hidden weights, [nU, 3*nU]
|
||||||
// b biases, [3*nUn]
|
// b biases, [3*nU]
|
||||||
|
|
||||||
// h is cell outputs at each time step [time, bS, nUn]
|
// h is cell outputs at each time step [time, bS, nU]
|
||||||
|
|
||||||
const int time = x->sizeAt(0);
|
const int time = x->sizeAt(0);
|
||||||
|
|
||||||
NDArray ht_1(*h0);
|
NDArray ht_1(*hLast);
|
||||||
|
|
||||||
// loop through time steps
|
// loop through time steps
|
||||||
for (int t = 0; t < time; ++t) {
|
for (int t = 0; t < time; ++t) {
|
||||||
|
@ -142,111 +148,214 @@ void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
|
||||||
auto xt = (*x)({t,t+1, 0,0, 0,0});
|
auto xt = (*x)({t,t+1, 0,0, 0,0});
|
||||||
auto ht = (*h)({t,t+1, 0,0, 0,0});
|
auto ht = (*h)({t,t+1, 0,0, 0,0});
|
||||||
|
|
||||||
//helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
|
// helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
|
||||||
//ht_1.assign(ht);
|
// ht_1.assign(ht);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
|
void gruCellBP(nd4j::LaunchContext* context,
|
||||||
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
|
const NDArray* x, const NDArray* hLast,
|
||||||
|
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
|
||||||
|
const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
|
||||||
|
NDArray* dLdx, NDArray* dLdhLast,
|
||||||
|
NDArray* dLdW, NDArray* dLdWc,
|
||||||
|
NDArray* dLdb, NDArray* dLdbc) {
|
||||||
|
|
||||||
// x input [bS, iS]
|
//Inputs:
|
||||||
// h0 previous cell output [bS, nUn], that is at previous time step t-1
|
// x input [bS, iS]
|
||||||
// Wx input-to-hidden weights, [iS, 3*nUn]
|
// hLast previous cell output [bS, nU], that is at previous time step t-1
|
||||||
// Wh hidden-to-hidden weights, [nUn, 3*nUn]
|
// W weights - [iS+nU, 2*nU] - reset and update gates
|
||||||
// b biases, [3*nUn]
|
// Wc C weights - [iS+nU, nU] - cell gate
|
||||||
// dLdh gradient wrt output, [bS,nUn], that is epsilon_next
|
// b r and u biases, [2*nU] - reset and update gates
|
||||||
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nUn]
|
// bc c biases, [nU] - cell gate
|
||||||
// dLdWh0 gradient wrt Wh at previous time step, [nUn, 3*nUn]
|
// dLdr gradient wrt reset gate, [bS, nU]
|
||||||
// dLdb0 gradient wrt b at previous time step, [3*nUn]
|
// dLdu gradient wrt update gate, [bS, nU]
|
||||||
|
// dLdc gradient wrt cell state, [bS, nU]
|
||||||
|
// dLdh gradient wrt current cell output, [bS, nU]
|
||||||
|
|
||||||
// dLdx gradient wrt x, [bS, iS], that is epsilon
|
//Outputs:
|
||||||
// dLdh0 gradient wrt h0, [bS, nUn]
|
// dLdx gradient wrt x, [bS, iS],
|
||||||
// dLdWx gradient wrt Wx, [iS, 3*nUn]
|
// dLdhLast gradient wrt hLast, [bS, nU]
|
||||||
// dLdWh gradient wrt Wh, [nUn, 3*nUn]
|
// dLdW gradient wrt W, [iS+nU, 2*nU]
|
||||||
// dLdb gradient wrt b at previous time step, [3*nUn]
|
// dLdWc gradient wrt Wc, [iS+nU, nU]
|
||||||
|
// dLdb gradient wrt bru [2*nU]
|
||||||
|
// dLdbc gradient wrt bc [nU]
|
||||||
|
|
||||||
// h is current cell output [bS, nUn], that is at current time step t
|
// * means element-wise product or so called Hadamard product
|
||||||
|
// × means matrix multiplication
|
||||||
|
|
||||||
|
/************************************************************************************************/
|
||||||
|
/******************************* THIS IS NOT OPTIMAZED CODE *************************************/
|
||||||
|
/*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/
|
||||||
|
|
||||||
|
const int bS = x->sizeAt(0);
|
||||||
|
const int iS = x->sizeAt(1);
|
||||||
|
const int nU = hLast->sizeAt(1);
|
||||||
|
|
||||||
|
NDArray xT = x->transpose(); // [iS, bS]
|
||||||
|
NDArray hLastT = hLast->transpose(); // [nU, bS]
|
||||||
|
|
||||||
|
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
|
||||||
|
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
|
||||||
|
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||||
|
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
|
||||||
|
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
|
||||||
|
|
||||||
|
NDArray br = (*b)({0, nU}); // [nU]
|
||||||
|
NDArray bu = (*b)({nU, 2*nU}); // [nU]
|
||||||
|
|
||||||
|
NDArray WrxT = Wrx.transpose(); // [nU, iS]
|
||||||
|
NDArray WuxT = Wux.transpose(); // [nU, iS]
|
||||||
|
NDArray WrhT = Wrh.transpose(); // [nU, nU]
|
||||||
|
NDArray WuhT = Wuh.transpose(); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray WcxT = Wcx.transpose(); // [nU, iS]
|
||||||
|
NDArray WchT = Wch.transpose(); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU]
|
||||||
|
NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU]
|
||||||
|
NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||||
|
NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU]
|
||||||
|
NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray dLdbr = (*dLdb)({0, nU}); // [nU]
|
||||||
|
NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU]
|
||||||
|
|
||||||
const int nUn = h0->sizeAt(1);
|
|
||||||
|
|
||||||
// ***** feed forward step ***** //
|
// ***** feed forward step ***** //
|
||||||
// gates = sigmoid(x*Wx + h0*Wh + b)
|
|
||||||
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nUn})) + mmul(*h0, (*Wh)({0,0, 0,2*nUn})) + (*b)({0,2*nUn})); // [bS, 2*nUn] + [bS, 2*nUn] + [1, 2*nUn] = [bS, 2*nUn]
|
|
||||||
// reset gate
|
// reset gate
|
||||||
auto r = gates({0,0, 0, nUn}); // [bS, nUn]
|
NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
|
r.applyTransform(transform::Sigmoid);
|
||||||
|
|
||||||
// update gate
|
// update gate
|
||||||
auto u = gates({0,0, nUn, 2*nUn}); // [bS, nUn]
|
NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
// ◦ means element-wise product or so called Hadamard product
|
u.applyTransform(transform::Sigmoid);
|
||||||
// n = tanh(x*Wx + (r◦h0)*Wh + b)
|
|
||||||
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nUn,3*nUn})) + mmul((*h0)*r, (*Wh)({0,0, 2*nUn,3*nUn})) + (*b)({2*nUn,3*nUn})); // [bS, nUn]
|
// cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc)
|
||||||
|
NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
|
c.applyTransform(transform::Tanh);
|
||||||
|
|
||||||
|
// h = (1 - u) * c + u * hPrev
|
||||||
|
|
||||||
|
|
||||||
// ***** back prop step ***** //
|
// ***** back prop step ***** //
|
||||||
auto Wxr = (*Wx)({0,0, 0, nUn});
|
|
||||||
auto Wxu = (*Wx)({0,0, nUn, 2*nUn});
|
|
||||||
auto Wxn = (*Wx)({0,0, 2*nUn,3*nUn});
|
|
||||||
auto Whr = (*Wh)({0,0, 0, nUn});
|
|
||||||
auto Whu = (*Wh)({0,0, nUn, 2*nUn});
|
|
||||||
auto Whn = (*Wh)({0,0, 2*nUn,3*nUn});
|
|
||||||
auto WxrT = Wxr.transpose();
|
|
||||||
auto WxuT = Wxu.transpose();
|
|
||||||
auto WxnT = Wxn.transpose();
|
|
||||||
auto WhrT = Whr.transpose();
|
|
||||||
auto WhuT = Whu.transpose();
|
|
||||||
auto WhnT = Whn.transpose();
|
|
||||||
auto xT = x->transpose();
|
|
||||||
auto h0T = h0->transpose();
|
|
||||||
|
|
||||||
auto dLdWxr = (*dLdWx)({0,0, 0, nUn});
|
// notations:
|
||||||
auto dLdWxu = (*dLdWx)({0,0, nUn, 2*nUn});
|
// Zr = x × Wrx + hLast × Wrh + br
|
||||||
auto dLdWxn = (*dLdWx)({0,0, 2*nUn,3*nUn});
|
// Zu = x × Wux + hLast × Wuh + bu
|
||||||
|
// Sr = sigmoid(Zr)
|
||||||
|
// Su = sigmoid(Zu)
|
||||||
|
// Zc = x × Wcx + (r * hlast) × Wch + bc
|
||||||
|
|
||||||
auto dLdWhr = (*dLdWh)({0,0, 0, nUn});
|
|
||||||
auto dLdWhu = (*dLdWh)({0,0, nUn, 2*nUn});
|
|
||||||
auto dLdWhn = (*dLdWh)({0,0, 2*nUn,3*nUn});
|
|
||||||
|
|
||||||
auto dLdbr = (*dLdb)({0, nUn});
|
// dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
|
||||||
auto dLdbu = (*dLdb)({nUn, 2*nUn});
|
// = dLdx_u + dLdx_c
|
||||||
auto dLdbn = (*dLdb)({2*nUn,3*nUn});
|
// dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
|
||||||
|
// dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
|
||||||
|
// dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
|
||||||
|
// dZcdr = (... * hLast) × WchT
|
||||||
|
// dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
|
||||||
|
// drdx = drdZr * dZrdx
|
||||||
|
// dZrdx = ... × WrxT
|
||||||
|
// dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
|
||||||
|
// finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
|
||||||
|
|
||||||
auto dhdu = *h0 - n; // [bS, nUn]
|
|
||||||
auto dhdn = 1.f - u; // [bS, nUn]
|
|
||||||
auto dSigdu = u * (1.f - u); // [bS, nUn]
|
|
||||||
auto dSigdr = r * (1.f - r); // [bS, nUn]
|
|
||||||
auto dActdn = 1.f - n * n; // [bS, nUn]
|
|
||||||
auto dndr = mmul(dActdn * (*h0), WhnT);
|
|
||||||
auto drdh0 = mmul(dSigdr, WhrT);
|
|
||||||
|
|
||||||
auto dLdn = (*dLdh) * dhdn;
|
// dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
|
||||||
auto dLdu = (*dLdh) * dhdu;
|
// = dLdhLast_h + dLdhLast_u + dLdhLast_c
|
||||||
auto dLdr = dLdn * dndr;
|
// dLdhLast_h = dLdh * dhdhLas = dLdh * u
|
||||||
|
// dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
|
||||||
|
// dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
|
||||||
|
// = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
|
||||||
|
// = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
|
||||||
|
// dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
|
||||||
|
// dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
|
||||||
|
// finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
|
||||||
|
// = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
|
||||||
|
|
||||||
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
|
|
||||||
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nUn]
|
|
||||||
|
|
||||||
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nUn]
|
// dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
|
||||||
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nUn,nUn]
|
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
|
||||||
|
// dZrdWrx = xT × ...
|
||||||
|
// finally dLdWrx = xT × (dLdr * drdZr)
|
||||||
|
|
||||||
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nUn]
|
|
||||||
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nUn,nUn]
|
|
||||||
|
|
||||||
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nUn]
|
// dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
|
||||||
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nUn,nUn]
|
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
|
||||||
|
// dZrdWrh = hLastT × ...
|
||||||
|
// finally dLdWrh = hLastT × (dLdr * drdZr)
|
||||||
|
|
||||||
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nUn]
|
|
||||||
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nUn]
|
|
||||||
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nUn]
|
|
||||||
|
|
||||||
if(dLdWx0 != nullptr)
|
// dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
|
||||||
*dLdWx += *dLdWx0;
|
// dZudWux = xT × ...
|
||||||
|
// dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)
|
||||||
|
|
||||||
if(dLdWh0 != nullptr)
|
|
||||||
*dLdWh += *dLdWh0;
|
|
||||||
|
|
||||||
if(dLdb0 != nullptr)
|
// dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
|
||||||
*dLdb += *dLdb0;
|
// dZudWuh = hLastT × ...
|
||||||
|
// finally dLdWuh = hLastT × (dLdu * dudZu)
|
||||||
|
|
||||||
|
|
||||||
|
// dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
|
||||||
|
// dZcdWcx = xT × ...
|
||||||
|
// finally dLdWcx = xT × (dLdc * dcdZc)
|
||||||
|
|
||||||
|
|
||||||
|
// dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
|
||||||
|
// dZcdWch = (r*hLast)^T × ...
|
||||||
|
// finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
|
||||||
|
|
||||||
|
|
||||||
|
// dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
|
||||||
|
// = dLdr * drdZr * dZrdbr
|
||||||
|
// dZrdbr = 1
|
||||||
|
// finally dLdbr = dLdr * drdZr
|
||||||
|
|
||||||
|
|
||||||
|
// dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
|
||||||
|
// dZudbu = 1
|
||||||
|
// finally dLdbu = dLdu * dudZu
|
||||||
|
|
||||||
|
|
||||||
|
// dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
|
||||||
|
// dZcdbc = 1
|
||||||
|
// finally dLdbc = dLdc * dcdZc
|
||||||
|
|
||||||
|
NDArray dhdc = 1.f - u; // [bS, nU]
|
||||||
|
NDArray dhdu = *hLast - c; // [bS, nU]
|
||||||
|
NDArray dudZu = u * dhdc; // [bS, nU]
|
||||||
|
NDArray drdZr = r * (1.f - r); // [bS, nU]
|
||||||
|
NDArray dcdZc = 1.f - c * c; // [bS, nU]
|
||||||
|
NDArray dLdZc = *dLdc * dcdZc; // [bS, nU]
|
||||||
|
NDArray dLdZu = *dLdu * dudZu; // [bS, nU]
|
||||||
|
NDArray dLdZr = *dLdr * drdZr; // [bS, nU]
|
||||||
|
|
||||||
|
// NDArray dLdc = *dLdh * dhdc; // [bS, nU]
|
||||||
|
// NDArray dLdu = *dLdh * dhdu; // [bS, nU]
|
||||||
|
// NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU]
|
||||||
|
|
||||||
|
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS]
|
||||||
|
|
||||||
|
dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU]
|
||||||
|
|
||||||
|
dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU]
|
||||||
|
dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU]
|
||||||
|
dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU]
|
||||||
|
dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU]
|
||||||
|
|
||||||
|
dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU]
|
||||||
|
dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU]
|
||||||
|
|
||||||
|
dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||||
|
dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||||
|
|
||||||
|
dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||||
}
|
}
|
||||||
|
|
||||||
// //////////////////////////////////////////////////////////////////////////
|
// //////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -255,34 +364,34 @@ void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h
|
||||||
// void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
|
// void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
|
||||||
|
|
||||||
// NDArray<T>* x = inArrs[0]; // input [time, bS, iS]
|
// NDArray<T>* x = inArrs[0]; // input [time, bS, iS]
|
||||||
// NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nUn], that is at previous time step t-1
|
// NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nU], that is at previous time step t-1
|
||||||
// NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nUn]
|
// NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nU]
|
||||||
// NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nUn, 3*nUn]
|
// NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nU, 3*nU]
|
||||||
// NDArray<T>* b = inArrs[4]; // biases, [3*nUn]
|
// NDArray<T>* b = inArrs[4]; // biases, [3*nU]
|
||||||
// NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nUn], that is epsilon_next
|
// NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nU], that is epsilon_next
|
||||||
|
|
||||||
// NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon
|
// NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon
|
||||||
// NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nUn]
|
// NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nU]
|
||||||
// NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nUn]
|
// NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nU]
|
||||||
// NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nUn, 3*nUn]
|
// NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nU, 3*nU]
|
||||||
// NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nUn]
|
// NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nU]
|
||||||
|
|
||||||
// const Nd4jLong time = x->sizeAt(0);
|
// const Nd4jLong time = x->sizeAt(0);
|
||||||
// const Nd4jLong bS = x->sizeAt(1);
|
// const Nd4jLong bS = x->sizeAt(1);
|
||||||
// const Nd4jLong iS = x->sizeAt(2);
|
// const Nd4jLong iS = x->sizeAt(2);
|
||||||
// const Nd4jLong nUn = hi->sizeAt(1);
|
// const Nd4jLong nU = hi->sizeAt(1);
|
||||||
|
|
||||||
// NDArray<T> h(hi->ordering(), {time, bS, nUn}); // feed forward output
|
// NDArray<T> h(hi->ordering(), {time, bS, nU}); // feed forward output
|
||||||
|
|
||||||
// // first step, time = 0, feed forward
|
// // first step, time = 0, feed forward
|
||||||
// NDArray<T> x0 = (*x)({{0,1}, {}, {}});
|
// NDArray<T> x0 = (*x)({{0,1}, {}, {}});
|
||||||
// NDArray<T> h0 = h({{0,1}, {}, {}});
|
// NDArray<T> hLast = h({{0,1}, {}, {}});
|
||||||
// helpers::gruCell<T>({&x0, hi, Wx, Wh, b}, &h0);
|
// helpers::gruCell<T>({&x0, hi, Wx, Wh, b}, &hLast);
|
||||||
|
|
||||||
// // first step, time = 0, back prop
|
// // first step, time = 0, back prop
|
||||||
// NDArray<T> dLdx0 = (*dLdx)({{0,1}, {}, {}});
|
// NDArray<T> dLdx0 = (*dLdx)({{0,1}, {}, {}});
|
||||||
// NDArray<T> dLdh0 = (*dLdh)({{0,1}, {}, {}});
|
// NDArray<T> dLdhLast = (*dLdh)({{0,1}, {}, {}});
|
||||||
// helpers::gruCellBP<T>({&x0, hi, Wx, Wh, b, &dLdh0, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb});
|
// helpers::gruCellBP<T>({&x0, hi, Wx, Wh, b, &dLdhLast, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb});
|
||||||
|
|
||||||
// // loop through the rest time steps
|
// // loop through the rest time steps
|
||||||
// for (Nd4jLong t = time-1; t > 0; --t) {
|
// for (Nd4jLong t = time-1; t > 0; --t) {
|
||||||
|
@ -310,4 +419,3 @@ void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
#include <ops/declarable/helpers/image_suppression.h>
|
#include <ops/declarable/helpers/image_suppression.h>
|
||||||
//#include <blas/NDArray.h>
|
//#include <blas/NDArray.h>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -28,9 +30,8 @@ namespace helpers {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
|
static void nonMaxSuppressionV2_(NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
|
||||||
std::vector<Nd4jLong> indices(scales->lengthOf());
|
std::vector<Nd4jLong> indices(scales->lengthOf());
|
||||||
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
|
||||||
for (size_t i = 0; i < indices.size(); ++i)
|
|
||||||
indices[i] = i;
|
|
||||||
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
||||||
|
|
||||||
// std::vector<int> selected(output->lengthOf());
|
// std::vector<int> selected(output->lengthOf());
|
||||||
|
@ -62,13 +63,15 @@ namespace helpers {
|
||||||
};
|
};
|
||||||
// int numSelected = 0;
|
// int numSelected = 0;
|
||||||
int numBoxes = boxes->sizeAt(0);
|
int numBoxes = boxes->sizeAt(0);
|
||||||
|
int numSelected = 0;
|
||||||
|
|
||||||
for (int i = 0, numSelected = 0; i < numBoxes && numSelected < output->lengthOf(); ++i) {
|
for (int i = 0; i < numBoxes; ++i) {
|
||||||
bool shouldSelect = true;
|
bool shouldSelect = numSelected < output->lengthOf();
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(numSelected))
|
||||||
for (int j = numSelected - 1; j >= 0; --j) {
|
for (int j = numSelected - 1; j >= 0; --j) {
|
||||||
|
if (shouldSelect)
|
||||||
if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold))) {
|
if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold))) {
|
||||||
shouldSelect = false;
|
shouldSelect = false;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (shouldSelect) {
|
if (shouldSelect) {
|
||||||
|
|
|
@ -24,20 +24,20 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename I, typename B>
|
||||||
static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) {
|
static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2)
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_COLLAPSE(2)
|
||||||
for (Nd4jLong i = 0; i < maxIndex; i++)
|
for (Nd4jLong i = 0; i < maxIndex; i++)
|
||||||
for(Nd4jLong k = 0; k < input->lengthOf(); k++)
|
for(Nd4jLong k = 0; k < input->lengthOf(); k++)
|
||||||
if (i < input->e<int>(k))
|
if (i < input->t<I>(k))
|
||||||
output->p<T>(k * maxIndex + i, T(1.0f));
|
output->t<B>(k * maxIndex + i) = B(true); //, T(1.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), sequenceMask_, (input, output, maxIndex), LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), LIBND4J_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -27,6 +27,7 @@
|
||||||
#include <helpers/TAD.h>
|
#include <helpers/TAD.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
#include <Loops.h>
|
#include <Loops.h>
|
||||||
|
#include <graph/RandomGenerator.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -81,7 +82,7 @@ static void trace_(const NDArray& input, NDArray& output) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) {
|
void randomShuffle_(NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
|
||||||
|
|
||||||
// check edge cases first
|
// check edge cases first
|
||||||
int temp;
|
int temp;
|
||||||
|
@ -95,16 +96,16 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
||||||
|
|
||||||
// apply Fisher-Yates shuffle
|
// apply Fisher-Yates shuffle
|
||||||
if(isInplace) {
|
if(isInplace) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||||
for(int i = firstDim-1; i > 0; --i) {
|
for(int i = firstDim-1; i > 0; --i) {
|
||||||
int r = rng.nextInt(0, i);
|
int r = rng.relativeInt(i) % i;
|
||||||
if(i == r)
|
if(i == r)
|
||||||
continue;
|
continue;
|
||||||
T _e0 = input.e<T>(i);
|
T t0 = input.t<T>(i);
|
||||||
T _e1 = input.e<T>(r);
|
T t1 = input.t<T>(r);
|
||||||
//math::nd4j_swap<T>(input(i), input(r));
|
//math::nd4j_swap<T>(input(i), input(r));
|
||||||
input.p<T>(i, _e1);
|
input.t<T>(i) = t1;
|
||||||
input.p<T>(r, _e0);
|
input.t<T>(r) = t0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -113,12 +114,12 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
||||||
output.p<T>(Nd4jLong(0), input.e<T>(0));
|
output.p<T>(Nd4jLong(0), input.e<T>(0));
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||||
for(int i = firstDim-1; i > 0; --i) {
|
for(int i = firstDim-1; i > 0; --i) {
|
||||||
int r = rng.nextInt(0, i);
|
int r = rng.relativeInt(i) % i;
|
||||||
output.p(i, input.e<T>(indices[r]));
|
output.t<T>(i) = input.t<T>(indices[r]);
|
||||||
if(i == r)
|
if(i == r)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
output.p(r, input.e<T>(indices[i]));
|
output.t<T>(r) = input.t<T>(indices[i]);
|
||||||
math::nd4j_swap<int>(indices[i], indices[r]);
|
math::nd4j_swap<int>(indices[i], indices[r]);
|
||||||
}
|
}
|
||||||
rng.rewindH(firstDim-1);
|
rng.rewindH(firstDim-1);
|
||||||
|
@ -132,9 +133,10 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
||||||
|
|
||||||
// apply Fisher-Yates shuffle
|
// apply Fisher-Yates shuffle
|
||||||
if(isInplace) {
|
if(isInplace) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
|
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
|
||||||
for(int i = firstDim-1; i > 0; --i) {
|
for(int i = firstDim - 1; i > 0; --i) {
|
||||||
int r = rng.nextInt(0, i);
|
int r = rng.relativeInt(i) % i;
|
||||||
|
|
||||||
if(i == r)
|
if(i == r)
|
||||||
continue;
|
continue;
|
||||||
subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r));
|
subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r));
|
||||||
|
@ -146,9 +148,9 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
||||||
std::vector<int> indices(firstDim);
|
std::vector<int> indices(firstDim);
|
||||||
std::iota(indices.begin(), indices.end(), 0);
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
bool isZeroShuffled = false;
|
bool isZeroShuffled = false;
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||||
for(int i = firstDim-1; i > 0; --i) {
|
for(int i = firstDim - 1; i > 0; --i) {
|
||||||
int r = rng.nextInt(0, i);
|
int r = rng.relativeInt(i) % i;
|
||||||
subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r]));
|
subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r]));
|
||||||
if(r == 0)
|
if(r == 0)
|
||||||
isZeroShuffled = true;
|
isZeroShuffled = true;
|
||||||
|
@ -167,11 +169,11 @@ void randomShuffle_(NDArray& input, NDArray& output, nd4j::random::RandomBuffer&
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) {
|
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
|
||||||
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,130 +15,209 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/adjust_hue.h>
|
#include <ops/declarable/helpers/adjust_hue.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
|
|
||||||
int numChannels = 3;
|
|
||||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
|
|
||||||
auto bIn = reinterpret_cast<T*>(xBuffer);
|
///////////////////////////////////////////////////////////////////
|
||||||
auto bOut = reinterpret_cast<T*>(zBuffer);
|
template <typename T>
|
||||||
static const int kChannelRange = 6;
|
static void _CUDA_G adjustHueCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
|
||||||
|
void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets,
|
||||||
|
const Nd4jLong numOfTads, const T delta, const int dimC) {
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
|
const T* x = reinterpret_cast<const T*>(vx);
|
||||||
auto i = bIn + e * numChannels;
|
T* z = reinterpret_cast<T*>(vz);
|
||||||
auto o = bOut + e * numChannels;
|
|
||||||
|
|
||||||
T h, v_min, v_max;
|
__shared__ int rank;
|
||||||
helpers::rgb_to_hv(i[0], i[1], i[2], &h, &v_min, &v_max);
|
__shared__ Nd4jLong xDimCstride, zDimCstride;
|
||||||
|
|
||||||
h += delta * kChannelRange;
|
if (threadIdx.x == 0) {
|
||||||
while (h < (T) 0.)
|
rank = shape::rank(xShapeInfo);
|
||||||
h += (T) kChannelRange;
|
xDimCstride = shape::stride(xShapeInfo)[dimC];
|
||||||
|
zDimCstride = shape::stride(zShapeInfo)[dimC];
|
||||||
while (h >= (T) kChannelRange)
|
|
||||||
h -= (T) kChannelRange;
|
|
||||||
|
|
||||||
helpers::hv_to_rgb(h, v_min, v_max, o, o + 1, o + 2);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
__syncthreads();
|
||||||
static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) {
|
|
||||||
int numChannels = 3;
|
|
||||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
static const int kChannelRange = 6;
|
|
||||||
|
|
||||||
auto bufferR = reinterpret_cast<T *>(xBuffer) + xOffsets[0];
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
auto bufferG = reinterpret_cast<T *>(xBuffer) + xOffsets[1];
|
|
||||||
auto bufferB = reinterpret_cast<T *>(xBuffer) + xOffsets[2];
|
|
||||||
|
|
||||||
auto outputR = reinterpret_cast<T *>(zBuffer) + zOffsets[0];
|
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
|
||||||
auto outputG = reinterpret_cast<T *>(zBuffer) + zOffsets[1];
|
|
||||||
auto outputB = reinterpret_cast<T *>(zBuffer) + zOffsets[2];
|
|
||||||
|
|
||||||
|
const T* xTad = x + xTadOffsets[i];
|
||||||
|
T* zTad = z + zTadOffsets[i];
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
|
T h, s, v;
|
||||||
auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
|
||||||
auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
|
||||||
auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
|
||||||
|
|
||||||
auto _ro = outputR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
|
||||||
auto _go = outputG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
|
||||||
auto _bo = outputB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
|
||||||
|
|
||||||
T h, v_min, v_max;
|
h += delta * 360;
|
||||||
helpers::rgb_to_hv(_ri[0], _gi[0], _bi[0], &h, &v_min, &v_max);
|
if(h > 360)
|
||||||
|
h -= 360;
|
||||||
|
else if(h < 0)
|
||||||
|
h += 360;
|
||||||
|
|
||||||
h += delta * kChannelRange;
|
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||||
while (h < (T) 0)
|
|
||||||
h += (T) kChannelRange;
|
|
||||||
|
|
||||||
while (h >= (T) kChannelRange)
|
|
||||||
h -= (T) kChannelRange;
|
|
||||||
|
|
||||||
helpers::hv_to_rgb(h, v_min, v_max, _ro, _go, _bo);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
///////////////////////////////////////////////////////////////////
|
||||||
static void _adjust_hue_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
template<typename T>
|
||||||
// numChannels is always 3
|
static _CUDA_H void adjustHueCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||||
auto tuples = array->lengthOf() / 3;
|
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
|
||||||
if (isNHWC) {
|
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
|
||||||
adjustHueSingleNHWCKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), tuples, delta);
|
const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC) {
|
||||||
} else {
|
|
||||||
// TODO: check this one
|
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {1, 2});
|
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {1, 2});
|
|
||||||
|
|
||||||
auto tadLength = shape::length(packX.primaryShapeInfo());
|
adjustHueCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, deltaScalarArr->e<T>(0), dimC);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void adjustHueCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* deltaScalarArr, const int dimC), LIBND4J_TYPES);
|
||||||
|
|
||||||
adjustHueSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
|
////////////////////////////////////////////////////////////////////////
|
||||||
}
|
void adjustHue(nd4j::LaunchContext* context, const NDArray *input, const NDArray* deltaScalarArr, NDArray *output, const int dimC) {
|
||||||
|
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
|
||||||
|
|
||||||
|
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
PointersManager manager(context, "adjustHue");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, deltaScalarArr});
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), adjustHueCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, deltaScalarArr, dimC), LIBND4J_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, deltaScalarArr});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
template <typename T>
|
||||||
|
static void _CUDA_G adjustHueSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
|
||||||
|
int numChannels = 3;
|
||||||
|
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
|
||||||
|
auto bIn = reinterpret_cast<T*>(xBuffer);
|
||||||
|
auto bOut = reinterpret_cast<T*>(zBuffer);
|
||||||
|
static const int kChannelRange = 6;
|
||||||
|
|
||||||
|
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
|
||||||
|
auto i = bIn + e * numChannels;
|
||||||
|
auto o = bOut + e * numChannels;
|
||||||
|
|
||||||
|
T h, v_min, v_max;
|
||||||
|
helpers::rgb_to_hv(i[0], i[1], i[2], &h, &v_min, &v_max);
|
||||||
|
|
||||||
|
h += delta * kChannelRange;
|
||||||
|
while (h < (T) 0.)
|
||||||
|
h += (T) kChannelRange;
|
||||||
|
|
||||||
|
while (h >= (T) kChannelRange)
|
||||||
|
h -= (T) kChannelRange;
|
||||||
|
|
||||||
|
helpers::hv_to_rgb(h, v_min, v_max, o, o + 1, o + 2);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) {
|
||||||
|
int numChannels = 3;
|
||||||
|
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
static const int kChannelRange = 6;
|
||||||
|
|
||||||
|
auto bufferR = reinterpret_cast<T *>(xBuffer) + xOffsets[0];
|
||||||
|
auto bufferG = reinterpret_cast<T *>(xBuffer) + xOffsets[1];
|
||||||
|
auto bufferB = reinterpret_cast<T *>(xBuffer) + xOffsets[2];
|
||||||
|
|
||||||
|
auto outputR = reinterpret_cast<T *>(zBuffer) + zOffsets[0];
|
||||||
|
auto outputG = reinterpret_cast<T *>(zBuffer) + zOffsets[1];
|
||||||
|
auto outputB = reinterpret_cast<T *>(zBuffer) + zOffsets[2];
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
|
||||||
static void _adjust_hue_batch(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
||||||
auto xType = array->dataType();
|
auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
||||||
|
auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
||||||
|
|
||||||
// numChannels is always 3
|
auto _ro = outputR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
||||||
auto tuples = array->lengthOf() / 3;
|
auto _go = outputG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
||||||
|
auto _bo = outputB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);;
|
||||||
|
|
||||||
if (isNHWC) {
|
T h, v_min, v_max;
|
||||||
// in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values
|
helpers::rgb_to_hv(_ri[0], _gi[0], _bi[0], &h, &v_min, &v_max);
|
||||||
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES);
|
|
||||||
} else {
|
|
||||||
// TODO: check this one
|
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {0, 2, 3});
|
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {0, 2, 3});
|
|
||||||
|
|
||||||
auto tadLength = shape::length(packX.primaryShapeInfo());
|
h += delta * kChannelRange;
|
||||||
|
while (h < (T) 0)
|
||||||
|
h += (T) kChannelRange;
|
||||||
|
|
||||||
adjustHueSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
|
while (h >= (T) kChannelRange)
|
||||||
}
|
h -= (T) kChannelRange;
|
||||||
|
|
||||||
|
helpers::hv_to_rgb(h, v_min, v_max, _ro, _go, _bo);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void _adjust_hue(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
template <typename T>
|
||||||
auto xType = array->dataType();
|
static void _adjust_hue_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
||||||
|
// numChannels is always 3
|
||||||
|
auto tuples = array->lengthOf() / 3;
|
||||||
|
if (isNHWC) {
|
||||||
|
adjustHueSingleNHWCKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), tuples, delta);
|
||||||
|
} else {
|
||||||
|
// TODO: check this one
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {1, 2});
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {1, 2});
|
||||||
|
|
||||||
float d = delta->e<float>(0);
|
auto tadLength = shape::length(packX.primaryShapeInfo());
|
||||||
if (array->rankOf() == 4) {
|
|
||||||
} else {
|
adjustHueSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
|
||||||
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void _adjust_hue_batch(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
||||||
|
auto xType = array->dataType();
|
||||||
|
|
||||||
|
// numChannels is always 3
|
||||||
|
auto tuples = array->lengthOf() / 3;
|
||||||
|
|
||||||
|
if (isNHWC) {
|
||||||
|
// in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES);
|
||||||
|
} else {
|
||||||
|
// TODO: check this one
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {0, 2, 3});
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {0, 2, 3});
|
||||||
|
|
||||||
|
auto tadLength = shape::length(packX.primaryShapeInfo());
|
||||||
|
|
||||||
|
adjustHueSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void _adjust_hue(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
||||||
|
auto xType = array->dataType();
|
||||||
|
|
||||||
|
float d = delta->e<float>(0);
|
||||||
|
if (array->rankOf() == 4) {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||||
|
} else {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, _adjust_hue_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,121 +15,198 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/adjust_saturation.h>
|
#include <ops/declarable/helpers/adjust_saturation.h>
|
||||||
|
#include <ops/declarable/helpers/adjust_hue.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
|
|
||||||
int numChannels = 3;
|
|
||||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
|
|
||||||
auto bIn = reinterpret_cast<T*>(xBuffer);
|
///////////////////////////////////////////////////////////////////
|
||||||
auto bOut = reinterpret_cast<T*>(zBuffer);
|
template <typename T>
|
||||||
static const int kChannelRange = 6;
|
static void _CUDA_G adjustSaturationCuda(const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
|
||||||
|
void* vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets,
|
||||||
|
const Nd4jLong numOfTads, const T factor, const int dimC) {
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
|
const T* x = reinterpret_cast<const T*>(vx);
|
||||||
auto i = bIn + e * numChannels;
|
T* z = reinterpret_cast<T*>(vz);
|
||||||
auto o = bOut + e * numChannels;
|
|
||||||
|
|
||||||
T h, s, v;
|
__shared__ int rank;
|
||||||
// Convert the RGB color to Hue/V-range.
|
__shared__ Nd4jLong xDimCstride, zDimCstride;
|
||||||
helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v);
|
|
||||||
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
|
|
||||||
|
|
||||||
// Convert the hue and v-range back into RGB.
|
if (threadIdx.x == 0) {
|
||||||
helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2);
|
rank = shape::rank(xShapeInfo);
|
||||||
}
|
xDimCstride = shape::stride(xShapeInfo)[dimC];
|
||||||
|
zDimCstride = shape::stride(zShapeInfo)[dimC];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
__syncthreads();
|
||||||
static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) {
|
|
||||||
int numChannels = 3;
|
|
||||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
|
||||||
static const int kChannelRange = 6;
|
|
||||||
|
|
||||||
auto bufferR = reinterpret_cast<T *>(xBuffer) + xOffsets[0];
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
auto bufferG = reinterpret_cast<T *>(xBuffer) + xOffsets[1];
|
|
||||||
auto bufferB = reinterpret_cast<T *>(xBuffer) + xOffsets[2];
|
|
||||||
|
|
||||||
auto outputR = reinterpret_cast<T *>(zBuffer) + zOffsets[0];
|
for (Nd4jLong i = tid; i < numOfTads; i += gridDim.x * blockDim.x) {
|
||||||
auto outputG = reinterpret_cast<T *>(zBuffer) + zOffsets[1];
|
|
||||||
auto outputB = reinterpret_cast<T *>(zBuffer) + zOffsets[2];
|
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
|
const T* xTad = x + xTadOffsets[i];
|
||||||
auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
T* zTad = z + zTadOffsets[i];
|
||||||
auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
|
||||||
auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
|
||||||
|
|
||||||
auto _ro = outputR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
T h, s, v;
|
||||||
auto _go = outputG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
|
||||||
auto _bo = outputB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
|
||||||
|
|
||||||
T h, s, v;
|
rgbToHsv<T>(xTad[0], xTad[xDimCstride], xTad[2 * xDimCstride], h, s, v);
|
||||||
// Convert the RGB color to Hue/V-range.
|
|
||||||
helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v);
|
s *= factor;
|
||||||
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
|
if(s > 1.f)
|
||||||
// Convert the hue and v-range back into RGB.
|
s = 1.f;
|
||||||
helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo);
|
else if(s < 0.f)
|
||||||
}
|
s = 0.f;
|
||||||
|
|
||||||
|
hsvToRgb<T>(h, s, v, zTad[0], zTad[zDimCstride], zTad[2 * zDimCstride]);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
///////////////////////////////////////////////////////////////////
|
||||||
static void _adjust_saturation_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
template<typename T>
|
||||||
// numChannels is always 3
|
static _CUDA_H void adjustSaturationCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||||
auto tuples = array->lengthOf() / 3;
|
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
|
||||||
|
const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC) {
|
||||||
|
|
||||||
if (isNHWC) {
|
adjustSaturationCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, xTadOffsets, vz, zShapeInfo, zTadOffsets, numOfTads, factorScalarArr->e<T>(0), dimC);
|
||||||
adjustSaturationSingleNHWCKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), tuples, delta);
|
}
|
||||||
} else {
|
BUILD_SINGLE_TEMPLATE(template void adjustSaturationCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, const Nd4jLong numOfTads, const NDArray* factorScalarArr, const int dimC), LIBND4J_TYPES);
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {1, 2});
|
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {1, 2});
|
|
||||||
|
|
||||||
auto tadLength = shape::length(packX.primaryShapeInfo());
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void adjustSaturation(nd4j::LaunchContext* context, const NDArray *input, const NDArray* factorScalarArr, NDArray *output, const int dimC) {
|
||||||
|
|
||||||
adjustSaturationSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {dimC});
|
||||||
}
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {dimC});
|
||||||
|
|
||||||
|
const Nd4jLong numOfTads = packX.numberOfTads();
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (numOfTads + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
PointersManager manager(context, "adjustSaturation");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, factorScalarArr});
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), adjustSaturationCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), packX.platformOffsets(), output->specialBuffer(), output->specialShapeInfo(), packZ.platformOffsets(), numOfTads, factorScalarArr, dimC), LIBND4J_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, factorScalarArr});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
template <typename T>
|
||||||
|
static void _CUDA_G adjustSaturationSingleNHWCKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong tuples, float delta) {
|
||||||
|
int numChannels = 3;
|
||||||
|
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
|
||||||
|
auto bIn = reinterpret_cast<T*>(xBuffer);
|
||||||
|
auto bOut = reinterpret_cast<T*>(zBuffer);
|
||||||
|
static const int kChannelRange = 6;
|
||||||
|
|
||||||
|
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
|
||||||
|
auto i = bIn + e * numChannels;
|
||||||
|
auto o = bOut + e * numChannels;
|
||||||
|
|
||||||
|
T h, s, v;
|
||||||
|
// Convert the RGB color to Hue/V-range.
|
||||||
|
helpers::rgb_to_hsv(i[0], i[1], i[2], &h, &s, &v);
|
||||||
|
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
|
||||||
|
|
||||||
|
// Convert the hue and v-range back into RGB.
|
||||||
|
helpers::hsv_to_rgb(h, s, v, o, o + 1, o + 2);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void _adjust_saturation_batch(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, void *zBuffer, Nd4jLong *zTadShapeInfo, Nd4jLong *zOffsets, Nd4jLong tadLength, Nd4jLong tuples, float delta) {
|
||||||
auto xType = array->dataType();
|
int numChannels = 3;
|
||||||
|
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
static const int kChannelRange = 6;
|
||||||
|
|
||||||
// numChannels is always 3
|
auto bufferR = reinterpret_cast<T *>(xBuffer) + xOffsets[0];
|
||||||
auto tuples = array->lengthOf() / 3;
|
auto bufferG = reinterpret_cast<T *>(xBuffer) + xOffsets[1];
|
||||||
|
auto bufferB = reinterpret_cast<T *>(xBuffer) + xOffsets[2];
|
||||||
|
|
||||||
if (isNHWC) {
|
auto outputR = reinterpret_cast<T *>(zBuffer) + zOffsets[0];
|
||||||
// in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values
|
auto outputG = reinterpret_cast<T *>(zBuffer) + zOffsets[1];
|
||||||
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES);
|
auto outputB = reinterpret_cast<T *>(zBuffer) + zOffsets[2];
|
||||||
} else {
|
|
||||||
// TODO: check this one
|
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {0, 2, 3});
|
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {0, 2, 3});
|
|
||||||
|
|
||||||
auto tadLength = shape::length(packX.primaryShapeInfo());
|
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
|
||||||
|
auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
||||||
|
auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
||||||
|
auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
||||||
|
|
||||||
adjustSaturationSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
|
auto _ro = outputR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
||||||
}
|
auto _go = outputG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
||||||
|
auto _bo = outputB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);
|
||||||
|
|
||||||
|
T h, s, v;
|
||||||
|
// Convert the RGB color to Hue/V-range.
|
||||||
|
helpers::rgb_to_hsv(_ri[0], _gi[0], _bi[0], &h, &s, &v);
|
||||||
|
s = nd4j::math::nd4j_min<T>((T) 1.0f, nd4j::math::nd4j_max<T>((T) 0.0f, s * delta));
|
||||||
|
// Convert the hue and v-range back into RGB.
|
||||||
|
helpers::hsv_to_rgb(h, s, v, _ro, _go, _bo);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void adjust_saturation(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
template <typename T>
|
||||||
auto xType = array->dataType();
|
static void _adjust_saturation_single(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
||||||
|
// numChannels is always 3
|
||||||
|
auto tuples = array->lengthOf() / 3;
|
||||||
|
|
||||||
float d = delta->e<float>(0);
|
if (isNHWC) {
|
||||||
if (array->rankOf() == 4) {
|
adjustSaturationSingleNHWCKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), array->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), tuples, delta);
|
||||||
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
} else {
|
||||||
} else {
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {1, 2});
|
||||||
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {1, 2});
|
||||||
}
|
|
||||||
|
auto tadLength = shape::length(packX.primaryShapeInfo());
|
||||||
|
|
||||||
|
adjustSaturationSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void _adjust_saturation_batch(nd4j::LaunchContext * context, NDArray *array, NDArray *output, float delta, bool isNHWC) {
|
||||||
|
auto xType = array->dataType();
|
||||||
|
|
||||||
|
// numChannels is always 3
|
||||||
|
auto tuples = array->lengthOf() / 3;
|
||||||
|
|
||||||
|
if (isNHWC) {
|
||||||
|
// in case of nhwc batch, we don't really care about examples: it's still bunch of RGB values
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, delta, isNHWC);, FLOAT_TYPES);
|
||||||
|
} else {
|
||||||
|
// TODO: check this one
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(array->getShapeInfo(), {0, 2, 3});
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {0, 2, 3});
|
||||||
|
|
||||||
|
auto tadLength = shape::length(packX.primaryShapeInfo());
|
||||||
|
|
||||||
|
adjustSaturationSingleNCHWKernel<T><<<256, 256, 1024, *context->getCudaStream()>>>(array->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, tuples, delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void adjust_saturation(nd4j::LaunchContext * context, NDArray *array, NDArray *output, NDArray* delta, bool isNHWC) {
|
||||||
|
auto xType = array->dataType();
|
||||||
|
|
||||||
|
float d = delta->e<float>(0);
|
||||||
|
if (array->rankOf() == 4) {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_batch, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||||
|
} else {
|
||||||
|
BUILD_SINGLE_SELECTOR(xType, _adjust_saturation_single, (context, array, output, d, isNHWC);, FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,20 +22,99 @@
|
||||||
#include <NativeOps.h>
|
#include <NativeOps.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <cuda_exception.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void dropoutSimple(NDArray const* input, NDArray* output, double probValue, int seed) {
|
static __global__ void dropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong* outputShape, double probVal, int inLen, nd4j::graph::RandomGenerator* nodeRng) {
|
||||||
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
__shared__ T const* input;
|
||||||
|
__shared__ T* output;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
input = reinterpret_cast<T const*>(inputBuf);
|
||||||
|
output = reinterpret_cast<T*>(outputBuf);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Nd4jLong e = 0; e < inLen; ++e) {
|
||||||
|
T val = nodeRng->relativeT(e, T(0.f), T(1.f));
|
||||||
|
|
||||||
|
if (double(val) < probVal)
|
||||||
|
output[shape::getIndexOffset(e, outputShape, inLen)] = T(input[shape::getIndexOffset(e, inputShape, inLen)] / probVal);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
|
|
||||||
|
template <typename T>
|
||||||
|
static void dropoutSimple(nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed) {
|
||||||
|
nd4j::graph::RandomGenerator nodeRng(3019L, seed);
|
||||||
|
int inLen = input->lengthOf();
|
||||||
|
nd4j::graph::RandomGenerator* dRandom;
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
|
||||||
|
auto err = cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator));
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::dropoutSimple: Cannot allocate device memory for random generator.", err);
|
||||||
|
}
|
||||||
|
err = cudaMemcpy(dRandom, &nodeRng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::dropoutSimple: Cannot set up device memory for random generator.", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
dropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, inLen, dRandom);
|
||||||
|
err = cudaFree(dRandom);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::dropoutSimple: Cannot deallocate device memory for random generator.", err);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (nd4j::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||||
|
|
||||||
|
if (reduceShape == nullptr){
|
||||||
|
dropoutSimple<T>(context.launchContext(), input, output, probValue, seed);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input");
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
|
||||||
|
reduceShape->syncToHost(); // to ensure that follows are actual
|
||||||
|
bool fit = true;
|
||||||
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
|
||||||
|
for( int i = 0; i < dims.size(); i++ ) {
|
||||||
|
if (fit) {
|
||||||
|
dims[i] = reduceShape->e<Nd4jLong>(i);
|
||||||
|
for (int e = 0; e < input->rankOf(); ++e)
|
||||||
|
if (fit)
|
||||||
|
if (input->sizeAt(e) % dims[i]) {
|
||||||
|
fit = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check dims to fit input
|
||||||
|
REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank.");
|
||||||
|
std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
|
||||||
|
chunk->assign(1.f);
|
||||||
|
//chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue);
|
||||||
|
//NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob);
|
||||||
|
dropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), probValue, seed);
|
||||||
|
// broadcast chunk to full matrix
|
||||||
|
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
|
||||||
|
dropOutMultiplier->assign(1.f);
|
||||||
|
|
||||||
|
*dropOutMultiplier += *chunk;
|
||||||
|
|
||||||
|
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,14 +127,121 @@ namespace helpers {
|
||||||
BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES);
|
||||||
|
|
||||||
/////////////////////////////////// backrpopagations ///////////////////////////////////////////////
|
/////////////////////////////////// backrpopagations ///////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void dropoutBPKernel(void* outputBuf, Nd4jLong* outputShape, void* gradOutBuf, Nd4jLong* gradOutShape, double probValue) {
|
||||||
|
__shared__ T* output;
|
||||||
|
__shared__ T* input;
|
||||||
|
__shared__ int len;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
len = shape::length(outputShape);
|
||||||
|
output = reinterpret_cast<T*>(outputBuf);
|
||||||
|
input = reinterpret_cast<T*>(gradOutBuf);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int e = tid; e < len; e += step) {
|
||||||
|
if (output[shape::getIndexOffset(e, outputShape, len)] != T(0.))
|
||||||
|
output[shape::getIndexOffset(e, outputShape, len)] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||||
return Status::OK();
|
int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue);
|
||||||
|
auto stream = context.launchContext()->getCudaStream();
|
||||||
|
|
||||||
|
if (ND4J_STATUS_OK == res)
|
||||||
|
dropoutBPKernel<T><<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue);
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void alphaDropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong* outputShape, double probValue, double alpha, double alpha1, double beta, int inLen, nd4j::graph::RandomGenerator* nodeRng) {
|
||||||
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
__shared__ T const* input;
|
||||||
|
__shared__ T* output;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
input = reinterpret_cast<T const*>(inputBuf);
|
||||||
|
output = reinterpret_cast<T*>(outputBuf);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto e = tid; e < inLen; e += step) {
|
||||||
|
T val = nodeRng->relativeT(e, T(0.f), T(1.f));
|
||||||
|
T xVal = input[shape::getIndexOffset(e, inputShape, inLen)];
|
||||||
|
output[shape::getIndexOffset(e, outputShape, inLen)] = (val >= T(probValue) ? T(alpha * beta + alpha1) : T(alpha * (double)xVal + alpha1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
static void alphaDropoutSimple(nd4j::LaunchContext* context, NDArray const* input, NDArray* output, int seed, double probValue, double alpha, double alpha1, double beta) {
|
||||||
|
nd4j::graph::RandomGenerator nodeRng(3019L, seed), *dRandom;
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
auto err = cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator));
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot allocate device memory for random generator.", err);
|
||||||
|
}
|
||||||
|
err = cudaMemcpy(dRandom, &nodeRng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot set up device memory for random generator.", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
alphaDropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, alpha, alpha1, beta, output->lengthOf(), dRandom);
|
||||||
|
|
||||||
|
err = cudaFree(dRandom);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot deallocate device memory for random generator.", err);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output,
|
static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output,
|
||||||
NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {
|
NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {
|
||||||
|
|
||||||
|
if (reduceShape == nullptr){
|
||||||
|
alphaDropoutSimple<T>(context.launchContext(), input, output, seed, probValue, alpha, alpha1, beta);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input");
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
|
||||||
|
reduceShape->syncToHost(); // to ensure that follows are actual
|
||||||
|
bool fit = true;
|
||||||
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
|
||||||
|
for( int i = 0; i < dims.size(); i++ ) {
|
||||||
|
if (fit) {
|
||||||
|
dims[i] = reduceShape->e<Nd4jLong>(i);
|
||||||
|
for (int e = 0; e < input->rankOf(); ++e)
|
||||||
|
if (fit)
|
||||||
|
if (input->sizeAt(e) % dims[i]) {
|
||||||
|
fit = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check dims to fit input
|
||||||
|
REQUIRE_TRUE(fit, 0, "alpha_dropout: Noise shape should fit to input rank.");
|
||||||
|
std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
|
||||||
|
chunk->assign(1.f);
|
||||||
|
//chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue);
|
||||||
|
//NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob);
|
||||||
|
alphaDropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), seed, probValue, alpha, alpha1, beta);
|
||||||
|
// broadcast chunk to full matrix
|
||||||
|
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
|
||||||
|
dropOutMultiplier->assign(1.f);
|
||||||
|
|
||||||
|
*dropOutMultiplier += *chunk;
|
||||||
|
|
||||||
|
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,7 +249,12 @@ namespace helpers {
|
||||||
int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output,
|
int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output,
|
||||||
NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {
|
NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {
|
||||||
|
|
||||||
return Status::OK();
|
int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta);
|
||||||
|
if (res == ND4J_STATUS_OK) {
|
||||||
|
(*output) *= alpha;
|
||||||
|
(*output) *= (*gradOut); //->applyPairwiseTransform<transform::Multiply>(gradOut, output, nullptr);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||||
|
|
|
@ -35,58 +35,88 @@ namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc,
|
void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc,
|
||||||
const NDArray* bru, const NDArray* bc,
|
const NDArray* b, const NDArray* bc,
|
||||||
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
|
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
|
||||||
|
|
||||||
//Inputs:
|
//Inputs:
|
||||||
// x input [bS x inSize]
|
// x input [bS, iS], iS - input size
|
||||||
// hLast previous cell output [bS x numUnits], that is at previous time step t-1
|
// hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
|
||||||
// Wru RU weights - [bS, 2*numUnits] - reset and update gates
|
// W RU weights - [iS+nU, 2*nU] - reset and update gates
|
||||||
// Wc C weights - [bS, numUnits] - cell gate
|
// Wc C weights - [iS+nU, nU] - cell gate
|
||||||
// bru r and u biases, [2*numUnits] - reset and update gates
|
// b r and u biases, [2*nU] - reset and update gates
|
||||||
// bc c biases, [numUnits] - cell gate
|
// bc c biases, [nU] - cell gate
|
||||||
|
|
||||||
//Outputs:
|
//Outputs:
|
||||||
// r Reset gate output [bS, numUnits]
|
// r Reset gate output [bS, nU]
|
||||||
// u Update gate output [bS, numUnits]
|
// u Update gate output [bS, nU]
|
||||||
// c Cell gate output [bS, numUnits]
|
// c Cell gate output [bS, nU]
|
||||||
// h current cell output [bS, numUnits]
|
// h current cell output [bS, nU]
|
||||||
|
|
||||||
const int nIn = x->sizeAt(1);
|
/***************************************************************************************/
|
||||||
const int nU = hLast->sizeAt(1); // number of units
|
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
|
||||||
|
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
|
||||||
|
|
||||||
//Concat inputs: [x, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)]
|
const int bS = x->sizeAt(0);
|
||||||
nd4j::ops::concat concatOp;
|
const int iS = x->sizeAt(1);
|
||||||
std::vector<NDArray*> inputs;
|
const int nU = hLast->sizeAt(1);
|
||||||
std::vector<double> targs;
|
|
||||||
std::vector<Nd4jLong> iargs({1}); //Axis = 1
|
|
||||||
std::vector<bool> bargs;
|
|
||||||
inputs.emplace_back(const_cast<NDArray*>(x));
|
|
||||||
inputs.emplace_back(const_cast<NDArray*>(hLast));
|
|
||||||
|
|
||||||
auto result = concatOp.execute(inputs, targs, iargs, bargs);
|
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
|
||||||
auto concatOut = result->at(0);
|
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
|
||||||
|
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||||
|
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||||
|
|
||||||
//mmul/z for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u)
|
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
|
||||||
auto m = mmul(*concatOut, *Wru); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 2*numUnits] = [bs, 4*numUnits]
|
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
|
||||||
m += (*bru);
|
|
||||||
|
|
||||||
sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz)
|
NDArray br = (*b)({0, nU}); // [nU]
|
||||||
auto mr = m({0,0, 0, nU});
|
NDArray bu = (*b)({nU, 2*nU}); // [nU]
|
||||||
auto mu = m({0,0, nU, 2*nU});
|
|
||||||
|
|
||||||
r->assign(&mr);
|
// × means matrix multipication
|
||||||
u->assign(&mu);
|
// * means element-wise product or so called Hadamard product
|
||||||
|
|
||||||
//Concatenated inputs: [x, yt-1 .* r]
|
// reset gate
|
||||||
auto yr = (*concatOut)({0,0, nIn, nIn+nU});
|
r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
yr *= (*r);
|
r->applyTransform(transform::Sigmoid);
|
||||||
|
|
||||||
//c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c)
|
// update gate
|
||||||
MmulHelper::mmul(concatOut, const_cast<NDArray*>(Wc), c, 1.0, 0.0); //c = 1.0 * concatOut * Wc + 0.0 * c
|
u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
|
u->applyTransform(transform::Sigmoid);
|
||||||
|
|
||||||
|
// cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc)
|
||||||
|
c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
|
c->applyTransform(transform::Tanh);
|
||||||
|
|
||||||
|
NDArray temp = 1.f - *c * *c;
|
||||||
|
|
||||||
|
// cell output
|
||||||
|
h->assign(*u * *hLast + (1.f - *u) * *c);
|
||||||
|
|
||||||
|
|
||||||
|
/***************************************************************************************/
|
||||||
|
/*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/
|
||||||
|
/***************************************************************************************/
|
||||||
|
/*
|
||||||
|
//Concat inputs: x + hLast : [bs, iS + nU]
|
||||||
|
NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU]
|
||||||
|
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
|
||||||
|
|
||||||
|
//mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u)
|
||||||
|
auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU]
|
||||||
|
// m += *bru;
|
||||||
|
|
||||||
|
m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz)
|
||||||
|
|
||||||
|
r->assign(m({0,0, 0, nU}));
|
||||||
|
u->assign(m({0,0, nU, 2*nU}));
|
||||||
|
|
||||||
|
// hLast = hLast * r
|
||||||
|
xhConcat({0,0, iS, iS+nU}) *= *r;
|
||||||
|
|
||||||
|
//c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c)
|
||||||
|
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
|
||||||
*c += *bc;
|
*c += *bc;
|
||||||
tanhInplace(*c);
|
c->applyTransform(transform::Tanh);
|
||||||
|
|
||||||
//Output: h = (1-u).*c + u .* hPrev
|
//Output: h = (1-u).*c + u .* hPrev
|
||||||
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
|
//auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast<NDArray*>(h)->assign(&hResult);
|
||||||
|
@ -94,115 +124,238 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
|
||||||
auto temp = (1.0f - *u);
|
auto temp = (1.0f - *u);
|
||||||
temp *= (*c);
|
temp *= (*c);
|
||||||
(*h) += temp;
|
(*h) += temp;
|
||||||
|
*/
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
|
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
|
||||||
|
|
||||||
|
// x input [time, bS, iS]
|
||||||
|
// hLast initial cell output (at time step = 0) [bS, nU]
|
||||||
|
// Wx input-to-hidden weights, [iS, 3*nU]
|
||||||
|
// Wh hidden-to-hidden weights, [nU, 3*nU]
|
||||||
|
// b biases, [3*nU]
|
||||||
|
|
||||||
|
// h is cell outputs at each time step [time, bS, nU]
|
||||||
|
|
||||||
|
const int time = x->sizeAt(0);
|
||||||
|
|
||||||
|
NDArray ht_1(*hLast);
|
||||||
|
|
||||||
|
// loop through time steps
|
||||||
|
for (int t = 0; t < time; ++t) {
|
||||||
|
|
||||||
|
auto xt = (*x)({t,t+1, 0,0, 0,0});
|
||||||
|
auto ht = (*h)({t,t+1, 0,0, 0,0});
|
||||||
|
|
||||||
|
// helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
|
||||||
|
// ht_1.assign(ht);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
|
void gruCellBP(nd4j::LaunchContext* context,
|
||||||
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
|
const NDArray* x, const NDArray* hLast,
|
||||||
|
const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
|
||||||
|
const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
|
||||||
|
NDArray* dLdx, NDArray* dLdhLast,
|
||||||
|
NDArray* dLdW, NDArray* dLdWc,
|
||||||
|
NDArray* dLdb, NDArray* dLdbc) {
|
||||||
|
|
||||||
// x input [bS, iS]
|
//Inputs:
|
||||||
// h0 previous cell output [bS, nU], that is at previous time step t-1
|
// x input [bS, iS]
|
||||||
// Wx input-to-hidden weights, [iS, 3*nU]
|
// hLast previous cell output [bS, nU], that is at previous time step t-1
|
||||||
// Wh hidden-to-hidden weights, [nU, 3*nU]
|
// W weights - [iS+nU, 2*nU] - reset and update gates
|
||||||
// b biases, [3*nU]
|
// Wc C weights - [iS+nU, nU] - cell gate
|
||||||
// dLdh gradient wrt output, [bS,nU], that is epsilon_next
|
// b r and u biases, [2*nU] - reset and update gates
|
||||||
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nU]
|
// bc c biases, [nU] - cell gate
|
||||||
// dLdWh0 gradient wrt Wh at previous time step, [nU, 3*nU]
|
// dLdr gradient wrt reset gate, [bS, nU]
|
||||||
// dLdb0 gradient wrt b at previous time step, [3*nU]
|
// dLdu gradient wrt update gate, [bS, nU]
|
||||||
|
// dLdc gradient wrt cell state, [bS, nU]
|
||||||
|
// dLdh gradient wrt current cell output, [bS, nU]
|
||||||
|
|
||||||
// dLdx gradient wrt x, [bS, iS], that is epsilon
|
//Outputs:
|
||||||
// dLdh0 gradient wrt h0, [bS, nU]
|
// dLdx gradient wrt x, [bS, iS],
|
||||||
// dLdWx gradient wrt Wx, [iS, 3*nU]
|
// dLdhLast gradient wrt hLast, [bS, nU]
|
||||||
// dLdWh gradient wrt Wh, [nU, 3*nU]
|
// dLdW gradient wrt W, [iS+nU, 2*nU]
|
||||||
// dLdb gradient wrt b at previous time step, [3*nU]
|
// dLdWc gradient wrt Wc, [iS+nU, nU]
|
||||||
|
// dLdb gradient wrt bru [2*nU]
|
||||||
|
// dLdbc gradient wrt bc [nU]
|
||||||
|
|
||||||
// h is current cell output [bS, nU], that is at current time step t
|
// * means element-wise product or so called Hadamard product
|
||||||
|
// × means matrix multiplication
|
||||||
|
|
||||||
|
/************************************************************************************************/
|
||||||
|
/******************************* THIS IS NOT OPTIMAZED CODE *************************************/
|
||||||
|
/*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/
|
||||||
|
|
||||||
|
const int bS = x->sizeAt(0);
|
||||||
|
const int iS = x->sizeAt(1);
|
||||||
|
const int nU = hLast->sizeAt(1);
|
||||||
|
|
||||||
|
NDArray xT = x->transpose(); // [iS, bS]
|
||||||
|
NDArray hLastT = hLast->transpose(); // [nU, bS]
|
||||||
|
|
||||||
|
NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU]
|
||||||
|
NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU]
|
||||||
|
NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||||
|
NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU]
|
||||||
|
NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU]
|
||||||
|
|
||||||
|
NDArray br = (*b)({0, nU}); // [nU]
|
||||||
|
NDArray bu = (*b)({nU, 2*nU}); // [nU]
|
||||||
|
|
||||||
|
NDArray WrxT = Wrx.transpose(); // [nU, iS]
|
||||||
|
NDArray WuxT = Wux.transpose(); // [nU, iS]
|
||||||
|
NDArray WrhT = Wrh.transpose(); // [nU, nU]
|
||||||
|
NDArray WuhT = Wuh.transpose(); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray WcxT = Wcx.transpose(); // [nU, iS]
|
||||||
|
NDArray WchT = Wch.transpose(); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU]
|
||||||
|
NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU]
|
||||||
|
NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU]
|
||||||
|
NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU]
|
||||||
|
NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU]
|
||||||
|
|
||||||
|
NDArray dLdbr = (*dLdb)({0, nU}); // [nU]
|
||||||
|
NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU]
|
||||||
|
|
||||||
const int nU = h0->sizeAt(1);
|
|
||||||
|
|
||||||
// ***** feed forward step ***** //
|
// ***** feed forward step ***** //
|
||||||
// gates = sigmoid(x*Wx + h0*Wh + b)
|
|
||||||
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nU})) + mmul(*h0, (*Wh)({0,0, 0,2*nU})) + (*b)({0,2*nU})); // [bS, 2*nU] + [bS, 2*nU] + [1, 2*nU] = [bS, 2*nU]
|
|
||||||
// reset gate
|
// reset gate
|
||||||
auto r = gates({0,0, 0, nU}); // [bS, nU]
|
NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
|
r.applyTransform(transform::Sigmoid);
|
||||||
|
|
||||||
// update gate
|
// update gate
|
||||||
auto u = gates({0,0, nU, 2*nU}); // [bS, nU]
|
NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
// ◦ means element-wise product or so called Hadamard product
|
u.applyTransform(transform::Sigmoid);
|
||||||
// n = tanh(x*Wx + (r◦h0)*Wh + b)
|
|
||||||
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nU,3*nU})) + mmul((*h0)*r, (*Wh)({0,0, 2*nU,3*nU})) + (*b)({2*nU,3*nU})); // [bS, nU]
|
// cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc)
|
||||||
|
NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
|
||||||
|
c.applyTransform(transform::Tanh);
|
||||||
|
|
||||||
|
// h = (1 - u) * c + u * hPrev
|
||||||
|
|
||||||
|
|
||||||
// ***** back prop step ***** //
|
// ***** back prop step ***** //
|
||||||
auto Wxr = (*Wx)({0,0, 0, nU});
|
|
||||||
auto Wxu = (*Wx)({0,0, nU, 2*nU});
|
|
||||||
auto Wxn = (*Wx)({0,0, 2*nU,3*nU});
|
|
||||||
auto Whr = (*Wh)({0,0, 0, nU});
|
|
||||||
auto Whu = (*Wh)({0,0, nU, 2*nU});
|
|
||||||
auto Whn = (*Wh)({0,0, 2*nU,3*nU});
|
|
||||||
auto WxrT = Wxr.transpose();
|
|
||||||
auto WxuT = Wxu.transpose();
|
|
||||||
auto WxnT = Wxn.transpose();
|
|
||||||
auto WhrT = Whr.transpose();
|
|
||||||
auto WhuT = Whu.transpose();
|
|
||||||
auto WhnT = Whn.transpose();
|
|
||||||
auto xT = x->transpose();
|
|
||||||
auto h0T = h0->transpose();
|
|
||||||
|
|
||||||
auto dLdWxr = (*dLdWx)({0,0, 0, nU});
|
// notations:
|
||||||
auto dLdWxu = (*dLdWx)({0,0, nU, 2*nU});
|
// Zr = x × Wrx + hLast × Wrh + br
|
||||||
auto dLdWxn = (*dLdWx)({0,0, 2*nU,3*nU});
|
// Zu = x × Wux + hLast × Wuh + bu
|
||||||
|
// Sr = sigmoid(Zr)
|
||||||
|
// Su = sigmoid(Zu)
|
||||||
|
// Zc = x × Wcx + (r * hlast) × Wch + bc
|
||||||
|
|
||||||
auto dLdWhr = (*dLdWh)({0,0, 0, nU});
|
|
||||||
auto dLdWhu = (*dLdWh)({0,0, nU, 2*nU});
|
|
||||||
auto dLdWhn = (*dLdWh)({0,0, 2*nU,3*nU});
|
|
||||||
|
|
||||||
auto dLdbr = (*dLdb)({0, nU});
|
// dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
|
||||||
auto dLdbu = (*dLdb)({nU, 2*nU});
|
// = dLdx_u + dLdx_c
|
||||||
auto dLdbn = (*dLdb)({2*nU,3*nU});
|
// dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
|
||||||
|
// dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
|
||||||
|
// dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
|
||||||
|
// dZcdr = (... * hLast) × WchT
|
||||||
|
// dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
|
||||||
|
// drdx = drdZr * dZrdx
|
||||||
|
// dZrdx = ... × WrxT
|
||||||
|
// dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
|
||||||
|
// finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
|
||||||
|
|
||||||
auto dhdu = *h0 - n; // [bS, nU]
|
|
||||||
auto dhdn = 1.f - u; // [bS, nU]
|
|
||||||
auto dSigdu = u * (1.f - u); // [bS, nU]
|
|
||||||
auto dSigdr = r * (1.f - r); // [bS, nU]
|
|
||||||
auto dActdn = 1.f - n * n; // [bS, nU]
|
|
||||||
auto dndr = mmul(dActdn * (*h0), WhnT);
|
|
||||||
auto drdh0 = mmul(dSigdr, WhrT);
|
|
||||||
|
|
||||||
auto dLdn = (*dLdh) * dhdn;
|
// dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
|
||||||
auto dLdu = (*dLdh) * dhdu;
|
// = dLdhLast_h + dLdhLast_u + dLdhLast_c
|
||||||
auto dLdr = dLdn * dndr;
|
// dLdhLast_h = dLdh * dhdhLas = dLdh * u
|
||||||
|
// dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
|
||||||
|
// dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
|
||||||
|
// = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
|
||||||
|
// = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
|
||||||
|
// dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
|
||||||
|
// dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
|
||||||
|
// finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
|
||||||
|
// = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
|
||||||
|
|
||||||
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
|
|
||||||
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nU]
|
|
||||||
|
|
||||||
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nU]
|
// dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
|
||||||
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nU,nU]
|
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
|
||||||
|
// dZrdWrx = xT × ...
|
||||||
|
// finally dLdWrx = xT × (dLdr * drdZr)
|
||||||
|
|
||||||
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nU]
|
|
||||||
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nU,nU]
|
|
||||||
|
|
||||||
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nU]
|
// dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
|
||||||
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nU,nU]
|
// = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
|
||||||
|
// dZrdWrh = hLastT × ...
|
||||||
|
// finally dLdWrh = hLastT × (dLdr * drdZr)
|
||||||
|
|
||||||
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nU]
|
|
||||||
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nU]
|
|
||||||
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nU]
|
|
||||||
|
|
||||||
if(dLdWx0 != nullptr)
|
// dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
|
||||||
*dLdWx += *dLdWx0;
|
// dZudWux = xT × ...
|
||||||
|
// dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)
|
||||||
|
|
||||||
if(dLdWh0 != nullptr)
|
|
||||||
*dLdWh += *dLdWh0;
|
|
||||||
|
|
||||||
if(dLdb0 != nullptr)
|
// dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
|
||||||
*dLdb += *dLdb0;
|
// dZudWuh = hLastT × ...
|
||||||
|
// finally dLdWuh = hLastT × (dLdu * dudZu)
|
||||||
|
|
||||||
|
|
||||||
|
// dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
|
||||||
|
// dZcdWcx = xT × ...
|
||||||
|
// finally dLdWcx = xT × (dLdc * dcdZc)
|
||||||
|
|
||||||
|
|
||||||
|
// dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
|
||||||
|
// dZcdWch = (r*hLast)^T × ...
|
||||||
|
// finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
|
||||||
|
|
||||||
|
|
||||||
|
// dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
|
||||||
|
// = dLdr * drdZr * dZrdbr
|
||||||
|
// dZrdbr = 1
|
||||||
|
// finally dLdbr = dLdr * drdZr
|
||||||
|
|
||||||
|
|
||||||
|
// dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
|
||||||
|
// dZudbu = 1
|
||||||
|
// finally dLdbu = dLdu * dudZu
|
||||||
|
|
||||||
|
|
||||||
|
// dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
|
||||||
|
// dZcdbc = 1
|
||||||
|
// finally dLdbc = dLdc * dcdZc
|
||||||
|
|
||||||
|
NDArray dhdc = 1.f - u; // [bS, nU]
|
||||||
|
NDArray dhdu = *hLast - c; // [bS, nU]
|
||||||
|
NDArray dudZu = u * dhdc; // [bS, nU]
|
||||||
|
NDArray drdZr = r * (1.f - r); // [bS, nU]
|
||||||
|
NDArray dcdZc = 1.f - c * c; // [bS, nU]
|
||||||
|
NDArray dLdZc = *dLdc * dcdZc; // [bS, nU]
|
||||||
|
NDArray dLdZu = *dLdu * dudZu; // [bS, nU]
|
||||||
|
NDArray dLdZr = *dLdr * drdZr; // [bS, nU]
|
||||||
|
|
||||||
|
// NDArray dLdc = *dLdh * dhdc; // [bS, nU]
|
||||||
|
// NDArray dLdu = *dLdh * dhdu; // [bS, nU]
|
||||||
|
// NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU]
|
||||||
|
|
||||||
|
dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS]
|
||||||
|
|
||||||
|
dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU]
|
||||||
|
|
||||||
|
dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU]
|
||||||
|
dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU]
|
||||||
|
dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU]
|
||||||
|
dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU]
|
||||||
|
|
||||||
|
dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU]
|
||||||
|
dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU]
|
||||||
|
|
||||||
|
dLdbr.assign(dLdZr.reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||||
|
dLdbu.assign(dLdZu.reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||||
|
|
||||||
|
dLdbc->assign(dLdZc.reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,12 +20,111 @@
|
||||||
|
|
||||||
#include <ops/declarable/helpers/hashcode.h>
|
#include <ops/declarable/helpers/hashcode.h>
|
||||||
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
void hashCode(LaunchContext *context, NDArray &array, NDArray &result) {
|
template <typename T>
|
||||||
|
static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong length) {
|
||||||
|
|
||||||
|
for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) {
|
||||||
|
auto blockBuffer = buffer + b * numBlocks;
|
||||||
|
|
||||||
|
Nd4jLong r = 1;
|
||||||
|
for (int e = threadIdx.x; e < blockSize && e + (b * numBlocks) < length; e += blockDim.x) {
|
||||||
|
auto v = longBytes<T>(blockBuffer[e]);
|
||||||
|
r = 31 * r + v;
|
||||||
|
}
|
||||||
|
|
||||||
|
tempBuffer[b] = r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong lastLength) {
|
||||||
|
|
||||||
|
for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) {
|
||||||
|
auto blockBuffer = tempBuffer + b * numBlocks;
|
||||||
|
|
||||||
|
Nd4jLong r = 1;
|
||||||
|
for (int e = threadIdx.x; e < blockSize && e + (b * numBlocks) < lastLength; e += blockDim.x) {
|
||||||
|
auto v = longBytes<T>(blockBuffer[e]);
|
||||||
|
r = 31 * r + v;
|
||||||
|
}
|
||||||
|
|
||||||
|
tempResult[b] = r;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static __global__ void lastStep(Nd4jLong* resultBuf, Nd4jLong* tempBufferA, Nd4jLong* tempResult, Nd4jLong length, Nd4jLong blockSize) {
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
if (length <= blockSize)
|
||||||
|
*resultBuf = *tempBufferA;
|
||||||
|
else
|
||||||
|
*resultBuf = *tempResult;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) {
|
||||||
|
auto blockSize = 32;
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
array.syncToDevice();
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&result}, {&array});
|
||||||
|
auto length = array.lengthOf();
|
||||||
|
int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1);
|
||||||
|
auto tempA = NDArrayFactory::create<Nd4jLong>('c', {numBlocks}, context);
|
||||||
|
auto tempB = NDArrayFactory::create<Nd4jLong>('c', { numBlocks / blockSize + 1}, context);
|
||||||
|
|
||||||
|
auto buffer = reinterpret_cast<T*>(array.specialBuffer()); //bufferAsT<T>();
|
||||||
|
auto tempBufferA = reinterpret_cast<Nd4jLong*>(tempA.specialBuffer()); //bufferAsT<Nd4jLong>();
|
||||||
|
auto tempBufferB = reinterpret_cast<Nd4jLong*>(tempB.specialBuffer()); //bufferAsT<Nd4jLong>();
|
||||||
|
|
||||||
|
// default buffer is the first one, because it might be the last one in case of small arrays (< blockSize)
|
||||||
|
auto tempBuffer = tempBufferA;
|
||||||
|
auto tempResult = tempBufferB;
|
||||||
|
|
||||||
|
// we divide array into 32 element chunks, and store intermediate results once
|
||||||
|
splitBufferToChuncks<T><<<numBlocks, length, 1024, *stream>>>(buffer, tempBuffer, numBlocks, blockSize, length);
|
||||||
|
|
||||||
|
// we replace pointer with intermediate one, and repeat only one chunk left
|
||||||
|
int iterationCount = 0;
|
||||||
|
while (numBlocks > 1) {
|
||||||
|
int lastLength = numBlocks;
|
||||||
|
numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1);
|
||||||
|
|
||||||
|
|
||||||
|
internalHash<Nd4jLong><<<numBlocks, lastLength, 1024, *stream>>>(tempBuffer, tempResult, numBlocks, blockSize, lastLength);
|
||||||
|
|
||||||
|
|
||||||
|
iterationCount++;
|
||||||
|
// swapping buffers
|
||||||
|
if (iterationCount % 2 == 0) {
|
||||||
|
tempBuffer = tempBufferA;
|
||||||
|
tempResult = tempBufferB;
|
||||||
|
} else {
|
||||||
|
tempBuffer = tempBufferB;
|
||||||
|
tempResult = tempBufferA;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//lastStep<Nd4jLong><<<1,1,128, *stream>>>(result.specialBuffer(), tempBufferA, tempResult, length, blockSize);
|
||||||
|
tempA.syncToHost();
|
||||||
|
tempB.syncToHost();
|
||||||
|
result.assign((length <= blockSize?tempA.e(0) : tempB.e(0)));
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&result}, {&array});
|
||||||
|
}
|
||||||
|
|
||||||
|
void hashCode(LaunchContext *context, NDArray &array, NDArray &result) {
|
||||||
|
BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void hashCode_, (LaunchContext* context, NDArray& array, NDArray& result), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
#include <ops/declarable/helpers/image_suppression.h>
|
#include <ops/declarable/helpers/image_suppression.h>
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
|
#include <NativeOps.h>
|
||||||
|
#include <cuda_exception.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -35,15 +37,16 @@ namespace helpers {
|
||||||
Nd4jLong next1[] = {nextIndex, 1};
|
Nd4jLong next1[] = {nextIndex, 1};
|
||||||
Nd4jLong next2[] = {nextIndex, 2};
|
Nd4jLong next2[] = {nextIndex, 2};
|
||||||
Nd4jLong next3[] = {nextIndex, 3};
|
Nd4jLong next3[] = {nextIndex, 3};
|
||||||
|
Nd4jLong* shapeOf = shape::shapeOf(boxesShape);
|
||||||
T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous2, 2)]);
|
Nd4jLong* strideOf = shape::stride(boxesShape);
|
||||||
T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous3, 2)]);
|
T minYPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, previous0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous2, 2)]);
|
||||||
T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous2, 2)]);
|
T minXPrev = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, previous1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous3, 2)]);
|
||||||
T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), previous3, 2)]);
|
T maxYPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, previous0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous2, 2)]);
|
||||||
T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next2, 2)]);
|
T maxXPrev = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, previous1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, previous3, 2)]);
|
||||||
T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next3, 2)]);
|
T minYNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, next0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next2, 2)]);
|
||||||
T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next0, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next2, 2)]);
|
T minXNext = nd4j::math::nd4j_min(boxes[shape::getOffset(0, shapeOf, strideOf, next1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next3, 2)]);
|
||||||
T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next1, 2)], boxes[shape::getOffset(0, shape::shapeOf(boxesShape), shape::stride(boxesShape), next3, 2)]);
|
T maxYNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, next0, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next2, 2)]);
|
||||||
|
T maxXNext = nd4j::math::nd4j_max(boxes[shape::getOffset(0, shapeOf, strideOf, next1, 2)], boxes[shape::getOffset(0, shapeOf, strideOf, next3, 2)]);
|
||||||
|
|
||||||
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
|
T areaPrev = (maxYPrev - minYPrev) * (maxXPrev - minXPrev);
|
||||||
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
|
T areaNext = (maxYNext - minYNext) * (maxXNext - minXNext);
|
||||||
|
@ -62,149 +65,101 @@ namespace helpers {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, typename I>
|
template <typename T, typename I>
|
||||||
static __global__ void nonMaxSuppressionKernel(T* boxes, Nd4jLong* boxesShape, I* indices, int* selectedIndices, Nd4jLong numBoxes, I* output, Nd4jLong* outputShape, T threshold) {
|
static __global__ void shouldSelectKernel(T* boxesBuf, Nd4jLong* boxesShape, I* indexBuf, I* selectedIndicesData, double threshold, int numSelected, int i, bool* shouldSelect) {
|
||||||
__shared__ Nd4jLong outputLen;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = gridDim.x * blockDim.x;
|
||||||
|
__shared__ bool shouldSelectShared;
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
outputLen = shape::length(outputShape);
|
shouldSelectShared = shouldSelect[0];
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
for (int j = numSelected - 1 - tid; j >= 0; j -= step) {
|
||||||
auto numSelected = blockIdx.x;
|
if (shouldSelectShared) {
|
||||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i],
|
||||||
auto step = blockDim.x * gridDim.x;
|
indexBuf[selectedIndicesData[j]], T(threshold)))
|
||||||
// for (int numSelected = blockIdx.x; numSelected < outputLen; numSelected += gridDim.x) {
|
shouldSelectShared = false;
|
||||||
for (int i = start; i < numBoxes; i += step) {
|
}
|
||||||
bool shouldSelect = true;
|
}
|
||||||
for (int j = numSelected - 1; shouldSelect && j >= 0; --j) {
|
__syncthreads();
|
||||||
if (needToSuppressWithThreshold<T>(boxes, boxesShape, indices[i], indices[selectedIndices[j]], threshold)) {
|
if (threadIdx.x == 0) {
|
||||||
shouldSelect = false;
|
*shouldSelect = shouldSelectShared;
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (shouldSelect) {
|
|
||||||
auto zPos = shape::getIndexOffset(numSelected, outputShape, outputLen);
|
|
||||||
output[zPos] = indices[i];
|
|
||||||
selectedIndices[numSelected] = i;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
template <typename I>
|
||||||
|
|
||||||
template <typename T, typename I>
|
static __global__ void copyIndices(void* indices, void* indicesLong, Nd4jLong len) {
|
||||||
static __global__ void sortIndices(I* indices, Nd4jLong* indexShape, T* scores, Nd4jLong* scoreShape) {
|
__shared__ I* indexBuf;
|
||||||
__shared__ Nd4jLong len;
|
__shared__ Nd4jLong* srcBuf;
|
||||||
// __shared__ Nd4jLong* sortedPart;
|
|
||||||
// __shared__ Nd4jLong part;
|
|
||||||
// __shared__ Nd4jLong partSize;
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
// blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs; // ceil
|
indexBuf = reinterpret_cast<I*>(indices);
|
||||||
// part = blockIdx.x / blocksPerArr;
|
srcBuf = reinterpret_cast<Nd4jLong*>(indicesLong);
|
||||||
|
|
||||||
len = shape::length(indexShape);
|
|
||||||
// __shared__ Nd4jLong* shmem = shared[];
|
|
||||||
// sortedPart = shmem;
|
|
||||||
}
|
}
|
||||||
|
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
for (int m = 0; m < len; m++) {
|
for (auto i = tid; i < len; i += step)
|
||||||
if (m % 2 == 0) {
|
indexBuf[i] = (I)srcBuf[i];
|
||||||
for (int tid = threadIdx.x; tid < len; tid += blockDim.x) {
|
|
||||||
auto top = 2 * tid + 1;
|
|
||||||
if (top < len) {
|
|
||||||
auto t0 = shape::getIndexOffset(top - 1, indexShape, len);
|
|
||||||
auto t1 = shape::getIndexOffset(top, indexShape, len);
|
|
||||||
auto z0 = shape::getIndexOffset(top - 1, scoreShape, len);
|
|
||||||
auto z1 = shape::getIndexOffset(top, scoreShape, len);
|
|
||||||
|
|
||||||
if (scores[t0] < scores[t1]) {
|
|
||||||
// swap indices first
|
|
||||||
Nd4jLong di0 = indices[t0];
|
|
||||||
indices[t0] = indices[t1];
|
|
||||||
indices[t1] = di0;
|
|
||||||
|
|
||||||
//swap scores next
|
|
||||||
// T dz0 = scores[z0];
|
|
||||||
// scores[z0] = scores[z1];
|
|
||||||
// scores[z1] = dz0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int tid = threadIdx.x; tid < len; tid += blockDim.x) {
|
|
||||||
auto top = 2 * tid + 2;
|
|
||||||
if (top < len) {
|
|
||||||
auto t0 = shape::getIndexOffset(top - 1, indexShape, len);
|
|
||||||
auto t1 = shape::getIndexOffset(top, indexShape, len);
|
|
||||||
auto z0 = shape::getIndexOffset(top - 1, scoreShape, len);
|
|
||||||
auto z1 = shape::getIndexOffset(top, scoreShape, len);
|
|
||||||
|
|
||||||
if (scores[t0] < scores[t1]) {
|
|
||||||
// swap indices first
|
|
||||||
Nd4jLong di0 = indices[t0];
|
|
||||||
indices[t0] = indices[t1];
|
|
||||||
indices[t1] = di0;
|
|
||||||
|
|
||||||
//swap scores next
|
|
||||||
// T dz0 = scores[z0];
|
|
||||||
// scores[z0] = scores[z1];
|
|
||||||
// scores[z1] = dz0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename I>
|
template <typename T, typename I>
|
||||||
static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
|
static void nonMaxSuppressionV2_(nd4j::LaunchContext* context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
NDArray::prepareSpecialUse({output}, {boxes, scales});
|
NDArray::prepareSpecialUse({output}, {boxes, scales});
|
||||||
NDArray* indices = NDArrayFactory::create_<I>('c', {scales->lengthOf()}); // - 1, scales->lengthOf()); //, scales->getContext());
|
std::unique_ptr<NDArray> indices(NDArrayFactory::create_<I>('c', {scales->lengthOf()})); // - 1, scales->lengthOf()); //, scales->getContext());
|
||||||
indices->linspace(0);
|
indices->linspace(0);
|
||||||
|
indices->syncToDevice(); // linspace only on CPU, so sync to Device as well
|
||||||
|
|
||||||
NDArray scores(*scales);
|
NDArray scores(*scales);
|
||||||
indices->syncToHost(); //linspace(0);
|
NativeOps nativeOps;
|
||||||
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
|
|
||||||
T* scoreBuf = reinterpret_cast<T*>(scores.specialBuffer());
|
Nd4jPointer extras[2] = {nullptr, stream};
|
||||||
sortIndices<T, I><<<1, 32, 128, *stream>>>(indexBuf, indices->specialShapeInfo(), scoreBuf, scores.specialShapeInfo());
|
|
||||||
|
nativeOps.sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
|
||||||
// TO DO: sort indices using scales as value row
|
// TO DO: sort indices using scales as value row
|
||||||
//std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
//std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
||||||
indices->tickWriteDevice();
|
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
|
||||||
indices->syncToHost();
|
|
||||||
indices->printIndexedBuffer("AFTERSORT OUTPUT");
|
|
||||||
NDArray selected = NDArrayFactory::create<int>({output->lengthOf()});
|
|
||||||
|
|
||||||
NDArray selectedIndices = NDArrayFactory::create<int>({output->lengthOf()});
|
NDArray selectedIndices = NDArrayFactory::create<I>('c', {output->lengthOf()});
|
||||||
int numSelected = 0;
|
int numSelected = 0;
|
||||||
int numBoxes = boxes->sizeAt(0);
|
int numBoxes = boxes->sizeAt(0);
|
||||||
T* boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer());
|
T* boxesBuf = reinterpret_cast<T*>(boxes->specialBuffer());
|
||||||
// Nd4jLong* indicesData = reinterpret_cast<Nd4jLong*>(indices->specialBuffer());
|
|
||||||
// int* selectedData = reinterpret_cast<int*>(selected.specialBuffer());
|
I* selectedIndicesData = reinterpret_cast<I*>(selectedIndices.specialBuffer());
|
||||||
int* selectedIndicesData = reinterpret_cast<int*>(selectedIndices.specialBuffer());
|
|
||||||
I* outputBuf = reinterpret_cast<I*>(output->specialBuffer());
|
I* outputBuf = reinterpret_cast<I*>(output->specialBuffer());
|
||||||
nonMaxSuppressionKernel<T, I><<<output->lengthOf(), 512, 1024, *stream>>>(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, numBoxes, outputBuf, output->specialShapeInfo(), T(threshold));
|
|
||||||
NDArray::registerSpecialUse({output}, {boxes, scales});
|
bool* shouldSelectD;
|
||||||
// for (int i = 0; i < boxes->sizeAt(0); ++i) {
|
auto err = cudaMalloc(&shouldSelectD, sizeof(bool));
|
||||||
// if (selected.size() >= output->lengthOf()) break;
|
if (err) {
|
||||||
// bool shouldSelect = true;
|
throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot allocate memory for bool flag", err);
|
||||||
// // Overlapping boxes are likely to have similar scores,
|
}
|
||||||
// // therefore we iterate through the selected boxes backwards.
|
for (I i = 0; i < boxes->sizeAt(0); ++i) {
|
||||||
// for (int j = numSelected - 1; j >= 0; --j) {
|
bool shouldSelect = numSelected < output->lengthOf();
|
||||||
// if (needToSuppressWithThreshold(*boxes, indices[i], indices[selectedIndices[j]], T(threshold)) {
|
if (shouldSelect) {
|
||||||
// shouldSelect = false;
|
err = cudaMemcpy(shouldSelectD, &shouldSelect, sizeof(bool), cudaMemcpyHostToDevice);
|
||||||
// break;
|
if (err) {
|
||||||
// }
|
throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to device", err);
|
||||||
// }
|
}
|
||||||
// if (shouldSelect) {
|
|
||||||
// selected.push_back(indices[i]);
|
shouldSelectKernel<T> <<< 128, 256, 1024, *stream >>>
|
||||||
// selectedIndices[numSelected++] = i;
|
(boxesBuf, boxes->specialShapeInfo(), indexBuf, selectedIndicesData, threshold, numSelected, i, shouldSelectD);
|
||||||
// }
|
err = cudaMemcpy(&shouldSelect, shouldSelectD, sizeof(bool), cudaMemcpyDeviceToHost);
|
||||||
// }
|
if (err) {
|
||||||
// for (size_t e = 0; e < selected.size(); ++e)
|
throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot set up bool flag to host", err);
|
||||||
// output->p<int>(e, selected[e]);
|
}
|
||||||
//
|
}
|
||||||
delete indices;
|
|
||||||
|
if (shouldSelect) {
|
||||||
|
cudaMemcpy(reinterpret_cast<I*>(output->specialBuffer()) + numSelected, indexBuf + i, sizeof(I), cudaMemcpyDeviceToDevice);
|
||||||
|
cudaMemcpy(selectedIndicesData + numSelected, &i, sizeof(I), cudaMemcpyHostToDevice);
|
||||||
|
numSelected++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cudaFree(shouldSelectD);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::nonMaxSuppressionV2: Cannot deallocate memory for bool flag", err);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
|
void nonMaxSuppressionV2(nd4j::LaunchContext * context, NDArray* boxes, NDArray* scales, int maxSize, double threshold, NDArray* output) {
|
||||||
|
|
|
@ -32,24 +32,24 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
// template <typename T>
|
||||||
static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
// static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
||||||
if (theFirst != theSecond) {
|
// if (theFirst != theSecond) {
|
||||||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
// auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
auto step = blockDim.x * gridDim.x;
|
// auto step = blockDim.x * gridDim.x;
|
||||||
for (auto i = start; i < N; i += step) {
|
// for (auto i = start; i < N; i += step) {
|
||||||
Nd4jLong iCoord1[] = {theFirst, i};
|
// Nd4jLong iCoord1[] = {theFirst, i};
|
||||||
Nd4jLong iCoord2[] = {theSecond, i};
|
// Nd4jLong iCoord2[] = {theSecond, i};
|
||||||
auto iIndex1 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord1, 2);
|
// auto iIndex1 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord1, 2);
|
||||||
auto iIndex2 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord2, 2);
|
// auto iIndex2 = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), iCoord2, 2);
|
||||||
//atomicExch(&matrix[iIndex1], matrix[iIndex2]);
|
// //atomicExch(&matrix[iIndex1], matrix[iIndex2]);
|
||||||
T e0 = matrix[iIndex1];
|
// T e0 = matrix[iIndex1];
|
||||||
T e1 = matrix[iIndex2];
|
// T e1 = matrix[iIndex2];
|
||||||
matrix[iIndex1] = e0;
|
// matrix[iIndex1] = e0;
|
||||||
matrix[iIndex2] = e1;
|
// matrix[iIndex2] = e1;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
// BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
|
// BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
|
||||||
//
|
//
|
||||||
// void swapRows(NDArray* matrix, int theFirst, int theSecond) {
|
// void swapRows(NDArray* matrix, int theFirst, int theSecond) {
|
||||||
|
@ -71,9 +71,14 @@ namespace helpers {
|
||||||
|
|
||||||
for (int i = start + 1; i < n; i += step) {
|
for (int i = start + 1; i < n; i += step) {
|
||||||
Nd4jLong pos[] = {i, i - 1};
|
Nd4jLong pos[] = {i, i - 1};
|
||||||
|
Nd4jLong posX[] = {i, i};
|
||||||
|
Nd4jLong posY[] = {i - 1, i - 1};
|
||||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||||
|
auto dxIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
||||||
|
auto dyIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2);
|
||||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
||||||
inverted[zIndex] = -input[xIndex];
|
inverted[zIndex] = -input[xIndex] / (input[dxIndex] * input[dyIndex]);
|
||||||
|
// math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex] / input[dIndex]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,10 +96,11 @@ namespace helpers {
|
||||||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
auto step = blockDim.x * gridDim.x;
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
for (int i = start + 1; i < n; i += step) {
|
for (int i = start; i < n; i += step) {
|
||||||
Nd4jLong pos[] = {i, i};
|
Nd4jLong pos[] = {i, i};
|
||||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
||||||
|
// math::atomics::nd4j_atomicDiv(&inverted[zIndex], input[xIndex]);
|
||||||
inverted[zIndex] /= input[xIndex];
|
inverted[zIndex] /= input[xIndex];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -113,16 +119,16 @@ namespace helpers {
|
||||||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
auto step = blockDim.x * gridDim.x;
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
for (int i = start + 1; i < n - 1; i += step) {
|
for (int i = start; i < n - 1; i += step) {
|
||||||
Nd4jLong pos[] = {i, i + 1};
|
Nd4jLong pos[] = {i, i + 1};
|
||||||
Nd4jLong posY[] = {i, i};
|
//Nd4jLong posY[] = {i, i};
|
||||||
Nd4jLong posX[] = {i + 1, i + 1};
|
Nd4jLong posX[] = {i + 1, i + 1};
|
||||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||||
auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posY, 2);
|
||||||
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||||
auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2);
|
auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2);
|
||||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
||||||
inverted[zIndex] -= input[xIndex] * inverted[iIndex] / input[yIndex];
|
math::atomics::nd4j_atomicAdd(&inverted[zIndex], - input[xIndex] * inverted[iIndex]); // / input[yIndex]);
|
||||||
//inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i)
|
//inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -142,16 +148,18 @@ namespace helpers {
|
||||||
// auto step = blockDim.x * gridDim.x;
|
// auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
for (int i = blockIdx.x + 2; i < n; i += gridDim.x) {
|
for (int i = blockIdx.x + 2; i < n; i += gridDim.x) {
|
||||||
for (int j = i - 2; j > -1; --j)
|
for (int j = i - 2; j >= 0; --j)
|
||||||
for (int k = threadIdx.x; k < i; k+= blockDim.x) {
|
for (int k = threadIdx.x; k < i; k += blockDim.x) {
|
||||||
Nd4jLong posZ[] = {i, j};
|
Nd4jLong posZ[] = {i, j};
|
||||||
Nd4jLong posX[] = {k, j};
|
Nd4jLong posY[] = {k, j};
|
||||||
Nd4jLong posY[] = {i, k};
|
Nd4jLong posX[] = {i, k};
|
||||||
|
Nd4jLong posD[] = {i, i};
|
||||||
|
|
||||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
||||||
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
|
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
|
||||||
|
auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
|
||||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
|
||||||
inverted[zIndex] -= inverted[yIndex] * input[xIndex];
|
math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex] / input[dIndex]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -176,13 +184,13 @@ namespace helpers {
|
||||||
Nd4jLong posZ[] = {i, j};
|
Nd4jLong posZ[] = {i, j};
|
||||||
Nd4jLong posY[] = {k, j};
|
Nd4jLong posY[] = {k, j};
|
||||||
Nd4jLong posX[] = {i, k};
|
Nd4jLong posX[] = {i, k};
|
||||||
Nd4jLong posD[] = {i, i};
|
// Nd4jLong posD[] = {i, i};
|
||||||
|
|
||||||
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
||||||
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
|
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
|
||||||
auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
|
// auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
|
||||||
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
|
||||||
inverted[zIndex] -= inverted[yIndex] * input[xIndex] / input[dIndex];
|
math::atomics::nd4j_atomicAdd(&inverted[zIndex], - inverted[yIndex] * input[xIndex]);// / input[dIndex]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -196,14 +204,18 @@ namespace helpers {
|
||||||
LaunchContext* context = inputMatrix->getContext();
|
LaunchContext* context = inputMatrix->getContext();
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
|
// invert main diagonal
|
||||||
|
upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
|
// invert the second diagonal
|
||||||
invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
|
// invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE);
|
||||||
|
|
||||||
void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -215,58 +227,58 @@ namespace helpers {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
//upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_NATIVE);
|
||||||
|
|
||||||
void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_NATIVE);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
// template <typename T>
|
||||||
static __global__ void lupKernel(T* compound, Nd4jLong* compoundShape, T* permutation, Nd4jLong* permutationShape, Nd4jLong rowNum) {
|
// static __global__ void lupKernel(T* compound, Nd4jLong* compoundShape, T* permutation, Nd4jLong* permutationShape, Nd4jLong rowNum) {
|
||||||
int swapCount = 0;
|
// int swapCount = 0;
|
||||||
for(int i = blockIdx.x; i < rowNum; i += gridDim.x ) {
|
// for(int i = blockIdx.x; i < rowNum; i += gridDim.x ) {
|
||||||
auto pivotValue = T(0.0);
|
// auto pivotValue = T(0.0);
|
||||||
auto pivot = -1;
|
// auto pivot = -1;
|
||||||
|
//
|
||||||
for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) {
|
// for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) {
|
||||||
Nd4jLong rowCoord[] = {rowCounter, i};
|
// Nd4jLong rowCoord[] = {rowCounter, i};
|
||||||
auto rowPos = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), rowCoord, 2);
|
// auto rowPos = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), rowCoord, 2);
|
||||||
if(nd4j::math::nd4j_abs(compound[rowPos]) > pivotValue ) {
|
// if(nd4j::math::nd4j_abs(compound[rowPos]) > pivotValue ) {
|
||||||
pivotValue = nd4j::math::nd4j_abs(compound[rowPos]);
|
// pivotValue = nd4j::math::nd4j_abs(compound[rowPos]);
|
||||||
pivot = rowCounter;
|
// pivot = rowCounter;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
if( pivotValue != T(0.0) ) {
|
// if( pivotValue != T(0.0) ) {
|
||||||
swapRows_<T>(compound, compoundShape, pivot, i, rowNum);
|
// swapRows_<T>(compound, compoundShape, pivot, i, rowNum);
|
||||||
swapRows_<T>(permutation, permutationShape, pivot, i, rowNum);
|
// swapRows_<T>(permutation, permutationShape, pivot, i, rowNum);
|
||||||
if (pivot != i)
|
// if (pivot != i)
|
||||||
swapCount++;
|
// swapCount++;
|
||||||
|
//
|
||||||
for( int j = i + 1; j < rowNum; j++ ) {
|
// for( int j = i + 1; j < rowNum; j++ ) {
|
||||||
Nd4jLong posJIbuf[] = {j, i};
|
// Nd4jLong posJIbuf[] = {j, i};
|
||||||
Nd4jLong posIIbuf[] = {i, i};
|
// Nd4jLong posIIbuf[] = {i, i};
|
||||||
auto posJI = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJIbuf, 2);
|
// auto posJI = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJIbuf, 2);
|
||||||
auto posII = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIIbuf, 2);
|
// auto posII = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIIbuf, 2);
|
||||||
|
//
|
||||||
compound[posJI] /= compound[posII];
|
// compound[posJI] /= compound[posII];
|
||||||
for( int k = i + 1; k < rowNum; k++ ) {
|
// for( int k = i + 1; k < rowNum; k++ ) {
|
||||||
Nd4jLong posJKbuf[] = {j, k};
|
// Nd4jLong posJKbuf[] = {j, k};
|
||||||
Nd4jLong posIKbuf[] = {i, k};
|
// Nd4jLong posIKbuf[] = {i, k};
|
||||||
auto posJK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJKbuf, 2);
|
// auto posJK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posJKbuf, 2);
|
||||||
auto posIK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIKbuf, 2);
|
// auto posIK = shape::getOffset(0, shape::shapeOf(compoundShape), shape::stride(compoundShape), posIKbuf, 2);
|
||||||
T arg = compound[posJI] * compound[posIK];
|
// T arg = compound[posJI] * compound[posIK];
|
||||||
compound[posJK] -= arg;
|
// compound[posJK] -= arg;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
template <typename T, typename F>
|
template <typename T, typename F>
|
||||||
static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) {
|
static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) {
|
||||||
|
@ -332,6 +344,30 @@ namespace helpers {
|
||||||
matrix[j] = (F)inputBuf[xIndex];
|
matrix[j] = (F)inputBuf[xIndex];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename F>
|
||||||
|
static __global__ void returnMatrix(void* output, Nd4jLong* outputShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) {
|
||||||
|
__shared__ F* matrix;
|
||||||
|
__shared__ T* outputBuf;
|
||||||
|
__shared__ Nd4jLong outputLen;
|
||||||
|
__shared__ Nd4jLong n2;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
matrix = reinterpret_cast<F*>(input);
|
||||||
|
outputBuf = reinterpret_cast<T*>(output);
|
||||||
|
outputLen = shape::length(inputShape);
|
||||||
|
n2 = rowLen * rowLen;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int k = pos + start, j = start; j < n2; k += step, j += step) {
|
||||||
|
auto zIndex = shape::getIndexOffset(k, outputShape, outputLen);
|
||||||
|
outputBuf[zIndex] = (T)matrix[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename F>
|
template <typename F>
|
||||||
static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) {
|
static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) {
|
||||||
__shared__ F* permutation;
|
__shared__ F* permutation;
|
||||||
|
@ -462,7 +498,7 @@ namespace helpers {
|
||||||
d_work,
|
d_work,
|
||||||
permutationBuf,
|
permutationBuf,
|
||||||
d_info);
|
d_info);
|
||||||
fillUpPermutation<float><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
|
fillUpPermutation<T><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
|
||||||
permutation->tickWriteDevice();
|
permutation->tickWriteDevice();
|
||||||
}
|
}
|
||||||
err = cudaFree(d_work);
|
err = cudaFree(d_work);
|
||||||
|
@ -483,7 +519,7 @@ namespace helpers {
|
||||||
// NDArray::registerSpecialUse({input}, {input});
|
// NDArray::registerSpecialUse({input}, {input});
|
||||||
input->tickWriteDevice();
|
input->tickWriteDevice();
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_NATIVE);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
|
static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
|
||||||
|
@ -504,32 +540,32 @@ namespace helpers {
|
||||||
output->assign(1.f);
|
output->assign(1.f);
|
||||||
for (int e = 0; e < output->lengthOf(); e++) {
|
for (int e = 0; e < output->lengthOf(); e++) {
|
||||||
Nd4jLong pos = e * n2;
|
Nd4jLong pos = e * n2;
|
||||||
if (matrix.dataType() == input->dataType())
|
// if (matrix.dataType() == input->dataType())
|
||||||
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
else
|
// else
|
||||||
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
|
|
||||||
if (matrix.dataType() == input->dataType())
|
// if (matrix.dataType() == input->dataType())
|
||||||
lup_<T>(context, &matrix, nullptr, nullptr);
|
lup_<T>(context, &matrix, nullptr, nullptr);
|
||||||
else
|
// else
|
||||||
lup_<float>(context, &matrix, nullptr, nullptr);
|
// lup_<float>(context, &matrix, nullptr, nullptr);
|
||||||
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
|
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
|
||||||
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
|
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
|
||||||
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
|
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
|
||||||
if (matrix.dataType() == input->dataType())
|
// if (matrix.dataType() == input->dataType())
|
||||||
determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||||
else
|
// else
|
||||||
determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
// determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||||
}
|
}
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);
|
||||||
|
|
||||||
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_NATIVE);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -552,22 +588,22 @@ namespace helpers {
|
||||||
output->assign(1.f);
|
output->assign(1.f);
|
||||||
for (int e = 0; e < output->lengthOf(); e++) {
|
for (int e = 0; e < output->lengthOf(); e++) {
|
||||||
Nd4jLong pos = e * n2;
|
Nd4jLong pos = e * n2;
|
||||||
if (matrix.dataType() == input->dataType())
|
// if (matrix.dataType() == input->dataType())
|
||||||
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
else
|
// else
|
||||||
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
// fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
|
|
||||||
if (matrix.dataType() == input->dataType())
|
// if (matrix.dataType() == input->dataType())
|
||||||
lup_<T>(context, &matrix, nullptr, nullptr);
|
lup_<T>(context, &matrix, nullptr, nullptr);
|
||||||
else
|
// else
|
||||||
lup_<float>(context, &matrix, nullptr, nullptr);
|
// lup_<float>(context, &matrix, nullptr, nullptr);
|
||||||
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
|
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
|
||||||
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
|
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
|
||||||
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
|
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
|
||||||
if (matrix.dataType() == input->dataType())
|
// if (matrix.dataType() == input->dataType())
|
||||||
determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||||
else
|
// else
|
||||||
determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
// determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||||
}
|
}
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
|
||||||
|
@ -576,10 +612,10 @@ namespace helpers {
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);
|
||||||
|
|
||||||
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_NATIVE);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -597,10 +633,12 @@ namespace helpers {
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
xShapeOf = shape::shapeOf(lowerShape);
|
xShapeOf = shape::shapeOf(lowerShape);
|
||||||
yShapeOf = shape::shapeOf(upperShape);
|
|
||||||
zShapeOf = shape::shapeOf(matrixShape);
|
|
||||||
xStrideOf = shape::stride(lowerShape);
|
xStrideOf = shape::stride(lowerShape);
|
||||||
|
|
||||||
|
yShapeOf = shape::shapeOf(upperShape);
|
||||||
yStrideOf = shape::stride(upperShape);
|
yStrideOf = shape::stride(upperShape);
|
||||||
|
|
||||||
|
zShapeOf = shape::shapeOf(matrixShape);
|
||||||
zStrideOf = shape::stride(matrixShape);
|
zStrideOf = shape::stride(matrixShape);
|
||||||
lowerMatrix = reinterpret_cast<T*>(lowerBuf);
|
lowerMatrix = reinterpret_cast<T*>(lowerBuf);
|
||||||
upperMatrix = reinterpret_cast<T*>(upperBuf);
|
upperMatrix = reinterpret_cast<T*>(upperBuf);
|
||||||
|
@ -610,15 +648,16 @@ namespace helpers {
|
||||||
|
|
||||||
for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it
|
for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it
|
||||||
for (int j = threadIdx.x; j < n; j += blockDim.x) {
|
for (int j = threadIdx.x; j < n; j += blockDim.x) {
|
||||||
Nd4jLong posX[] = {j, k};
|
Nd4jLong posX[] = {k, j};
|
||||||
|
Nd4jLong posD[] = {j, j};
|
||||||
auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2);
|
auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2);
|
||||||
auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2);
|
auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2);
|
||||||
auto pos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2);
|
auto iPos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2);
|
||||||
if (k <= j)
|
auto dPos = shape::getOffset(0, zShapeOf, zStrideOf, posD, 2);
|
||||||
lowerMatrix[xPos] = matrix[pos];//(k, j);
|
if (k >= j)
|
||||||
|
lowerMatrix[xPos] = matrix[iPos];//(k, j);
|
||||||
else
|
else
|
||||||
upperMatrix[yPos] = matrix[pos]; //k, j);
|
upperMatrix[yPos] = matrix[iPos]; //k, j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -639,38 +678,26 @@ namespace helpers {
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1});
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1});
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
|
// PRAGMA_OMP_PARALLEL_FOR
|
||||||
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
|
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
|
||||||
fillMatrix<T, float><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
|
fillMatrix<T, T><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
|
||||||
permutation.assign(0.f);
|
|
||||||
lup_<float>(context, &matrix, &compound, &permutation);
|
|
||||||
matrix.tickWriteDevice();
|
matrix.tickWriteDevice();
|
||||||
permutation.tickWriteDevice();
|
compound.assign(matrix);
|
||||||
permutation.printIndexedBuffer("PERMUTE");
|
lup_<T>(context, &compound, nullptr, nullptr);
|
||||||
lower.setIdentity(); // set up U to identity matrix
|
fillLowerUpperKernel<T><<<n, n, 128>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), compound.specialBuffer(), compound.specialShapeInfo(), n);
|
||||||
upper.setIdentity();
|
matrix.assign(0);
|
||||||
fillLowerUpperKernel<float><<<1, n2, 128>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n);
|
invertUpperMatrix(&upper, &matrix); // U^{-1}
|
||||||
lower.tickWriteDevice();
|
compound.assign(0);
|
||||||
upper.tickWriteDevice();
|
invertLowerMatrix(&lower, &compound); // L{-1}
|
||||||
invertUpperMatrix(&upper, &matrix);
|
|
||||||
invertLowerMatrix(&lower, &upper);
|
|
||||||
lower.tickWriteDevice();
|
|
||||||
upper.tickWriteDevice();
|
|
||||||
lower.printIndexedBuffer("LOWER");
|
|
||||||
upper.printIndexedBuffer("UPPER");
|
|
||||||
|
|
||||||
nd4j::MmulHelper::mmul(&matrix, &upper, &compound, 1.0, 0.0);
|
nd4j::MmulHelper::mmul(&matrix, &compound, &upper, 1.0, 0.0);
|
||||||
nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0);
|
returnMatrix<T, T><<<1, n2, 128, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), i * n2, n);
|
||||||
// for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
|
||||||
// output->t<T>(k) = matrix.template t<T>(row++);
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_NATIVE);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
|
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
|
||||||
|
@ -803,7 +830,7 @@ namespace helpers {
|
||||||
return cholesky_(context, input, output, inplace);
|
return cholesky_(context, input, output, inplace);
|
||||||
}
|
}
|
||||||
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);
|
||||||
|
|
||||||
__global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) {
|
__global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) {
|
||||||
__shared__ double* output;
|
__shared__ double* output;
|
||||||
|
|
|
@ -143,7 +143,7 @@ namespace helpers {
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void _reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){
|
static void reverseSequence_(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim){
|
||||||
int posOfNonUnityDim = -1;
|
int posOfNonUnityDim = -1;
|
||||||
seqLengths->syncToHost();
|
seqLengths->syncToHost();
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
@ -193,7 +193,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) {
|
void reverseSequence(nd4j::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), _reverseSequence, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), reverseSequence_, (context, input, seqLengths, output, seqDim, batchDim), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -391,7 +391,7 @@ static void scatterCudaLauncher(const int blocksPerGrid, const int threadsPerBlo
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
|
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
|
||||||
|
|
||||||
PointersManager manager(context, "scatterND");
|
PointersManager manager(context, "scatter");
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&output}, {&updates, &indices});
|
NDArray::prepareSpecialUse({&output}, {&updates, &indices});
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,427 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/segment.h>
|
||||||
|
#include <ops/declarable/helpers/segment_common.h>
|
||||||
|
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// Segment ops linear kernels
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template<typename T, typename I>
|
||||||
|
static __global__ void
|
||||||
|
segmentMaxLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
|
||||||
|
void *output, Nd4jLong *outputShape) {
|
||||||
|
__shared__
|
||||||
|
T *val;
|
||||||
|
__shared__
|
||||||
|
Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__
|
||||||
|
T *x;
|
||||||
|
__shared__
|
||||||
|
T *z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||||
|
segment = blockIdx.x / threadsPerSegment;
|
||||||
|
x = reinterpret_cast<T *>(input);
|
||||||
|
z = reinterpret_cast<T *>(output);
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
val = reinterpret_cast<T *>(shmem);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
if (segment < numOfClasses) {
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
|
||||||
|
val[segment] = z[zIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template<typename T, typename I>
|
||||||
|
static __global__ void
|
||||||
|
unsortedSegmentMaxLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
|
||||||
|
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
|
||||||
|
Nd4jLong *outputShape) {
|
||||||
|
__shared__
|
||||||
|
T *val;
|
||||||
|
__shared__
|
||||||
|
Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__
|
||||||
|
T *x;
|
||||||
|
__shared__
|
||||||
|
T *z;
|
||||||
|
__shared__
|
||||||
|
I *y; //int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
segment = blockIdx.x;
|
||||||
|
x = reinterpret_cast<T *>(input);
|
||||||
|
z = reinterpret_cast<T *>(output);
|
||||||
|
y = reinterpret_cast<I *>(indices);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
//start = starts[segment];
|
||||||
|
//finish = start + lengths[segment];
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
|
||||||
|
else
|
||||||
|
z[zIndex] = -DataTypeUtils::max<T>();
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
if (y[yIndex] == segment) {
|
||||||
|
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMaxTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads,
|
||||||
|
Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf,
|
||||||
|
Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets, T filler = 0) {
|
||||||
|
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ int start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||||
|
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||||
|
len = shape::length(inputTads);
|
||||||
|
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
total = shape::sizeAt(inputShape, 0);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto idx = blockIdx.x;
|
||||||
|
if (blockIdx.x <= total) {
|
||||||
|
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||||
|
if (blockIdx.x == start) {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
z[zIndex] = x[xIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
nd4j::math::atomics::nd4j_atomicMax(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void segmentMaxFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
//int numClasses = output->sizeAt(0);
|
||||||
|
// if input is a vector: (as if in doc sample)
|
||||||
|
//Nd4jLong idx = indices->e<Nd4jLong>(0);
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
indices->syncToHost();
|
||||||
|
Nd4jLong numOfClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(256, 512, 256);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
|
||||||
|
segmentMaxLinearKernel<T,I><<<numOfClasses, input->lengthOf(), numOfClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
segmentMaxTadKernel<T,I><<<packX.numberOfTads(), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void segmentMaxFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMaxFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void segmentMaxFunctor_, (LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void unsortedSegmentMaxFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
|
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||||
|
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||||
|
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.getSpecialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.getSpecialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
unsortedSegmentMaxLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
dims.x = input->sizeAt(0);
|
||||||
|
output->assign(-DataTypeUtils::max<T>());
|
||||||
|
segmentMaxTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void unsortedSegmentMaxFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMaxFunctor_, (context, input, indices, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// segment max
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMaxBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||||
|
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||||
|
void* outputBuf, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradIn;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, gradLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (auto e = start; e < xLen; e += step) {
|
||||||
|
|
||||||
|
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||||
|
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
auto classIndex = y[yOffset];
|
||||||
|
auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen);
|
||||||
|
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||||
|
|
||||||
|
if (nd4j::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) {
|
||||||
|
z[zOffset] = gradOut[gradOffsetO];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMaxBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||||
|
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||||
|
void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||||
|
Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets,
|
||||||
|
Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad,
|
||||||
|
Nd4jLong* outOffsets) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradIn;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
yLen = shape::length(indicesShape);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
currentLen = shape::length(outTad);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||||
|
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||||
|
auto segment = y[yIndex];
|
||||||
|
T* current = x + inputOffsets[i];
|
||||||
|
T* currentOut = z + outOffsets[i];
|
||||||
|
T* in = gradIn + gradInOffsets[segment];
|
||||||
|
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||||
|
|
||||||
|
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||||
|
if (nd4j::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6))
|
||||||
|
currentOut[e] = outGrad[e];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
int segmentMaxFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
//int numOfClasses = gradOut->sizeAt(0);
|
||||||
|
// if input is a vector: (as if in doc sample)
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||||
|
segmentMaxFunctor_<T, I>(context, input, indices, &tempRes);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loop_size = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentMaxBPLinearKernel<T,I><<<1 + gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||||
|
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentMaxBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||||
|
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
int segmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMaxFunctorBP_, (context, input,
|
||||||
|
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int segmentMaxFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static int unsortedSegmentMaxFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
//int numOfClasses = gradOut->sizeAt(0);
|
||||||
|
// if input is a vector: (as if in doc sample)
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||||
|
unsortedSegmentMaxFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loop_size = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentMaxBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||||
|
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentMaxBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||||
|
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
int unsortedSegmentMaxFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,414 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/segment.h>
|
||||||
|
#include <ops/declarable/helpers/segment_common.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// Segment ops linear kernels
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMeanLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||||
|
segment = blockIdx.x / threadsPerSegment;
|
||||||
|
x = reinterpret_cast<T*>(input);
|
||||||
|
z = reinterpret_cast<T*>(output);
|
||||||
|
// extern __shared__ unsigned char shmem[];
|
||||||
|
// val = reinterpret_cast<T*>(shmem);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
//[zIndex] =
|
||||||
|
if (segment < numOfClasses) {
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
//val[segment] = ;
|
||||||
|
z[zIndex] = T(x[shape::getIndexOffset(start, inputShape, xLen)] / lengths[segment]);
|
||||||
|
// val[segment] = z[zIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
if (lengths[segment])
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex] / lengths[segment]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void unsortedSegmentMeanLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ I* y; //int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||||
|
segment = blockIdx.x;// / threadsPerSegment;
|
||||||
|
x = reinterpret_cast<T*>(input);
|
||||||
|
z = reinterpret_cast<T*>(output);
|
||||||
|
y = reinterpret_cast<I*>(indices);
|
||||||
|
// extern __shared__ unsigned char shmem[];
|
||||||
|
// val = reinterpret_cast<T*>(shmem);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
// if (segment < numOfClasses) {
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
//start = starts[segment];
|
||||||
|
//finish = start + lengths[segment];
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
z[zIndex] = T(x[shape::getIndexOffset(starts[segment], inputShape, xLen)] / T(lengths[segment]));
|
||||||
|
else
|
||||||
|
z[zIndex] = 0; //DataTypeUtils::max<T>();
|
||||||
|
// val[segment] = z[zIndex];
|
||||||
|
// }
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
if (y[yIndex] == segment && e != starts[segment]) {
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/T(lengths[segment])));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// SegmentMean kernel
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMeanTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||||
|
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||||
|
len = shape::length(inputTads);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
total = shape::sizeAt(inputShape, 0);
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto idx = blockIdx.x;
|
||||||
|
if (blockIdx.x <= total) {
|
||||||
|
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||||
|
if (blockIdx.x == start) {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
z[zIndex] = T(x[xIndex]/lengths[segment]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
if (lengths[segment])
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], T(x[xIndex]/lengths[segment]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// segmen mean
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void segmentMeanFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
|
||||||
|
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
segmentMeanLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
segmentMeanTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void segmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentMeanFunctor_, (context, input, indices, output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void segmentMeanFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void unsortedSegmentMeanFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
|
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||||
|
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||||
|
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
unsortedSegmentMeanLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
output->assign(0);
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
dims.x = input->sizeAt(0);
|
||||||
|
segmentMeanTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void unsortedSegmentMeanFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMeanFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
|
FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMeanFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMeanBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||||
|
int* lengths, void* outputBuf, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradIn;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, gradLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (auto e = start; e < xLen; e += step) {
|
||||||
|
|
||||||
|
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||||
|
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
auto classIndex = y[yOffset];
|
||||||
|
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||||
|
|
||||||
|
z[zOffset] = T(gradOut[gradOffsetO] / float(lengths[classIndex]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMeanBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
|
||||||
|
void* indicesBuf, Nd4jLong* indicesShape, int* lengths, void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||||
|
Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
yLen = shape::length(indicesShape);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
currentLen = shape::length(outTad);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||||
|
// auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||||
|
auto segment = y[i]; //yIndex];
|
||||||
|
T* currentOut = z + outOffsets[i];
|
||||||
|
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||||
|
|
||||||
|
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outTad, currentLen);
|
||||||
|
auto gradIndex = shape::getIndexOffset(e, gradOutTad, gradLen);
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
currentOut[zIndex] = T(outGrad[gradIndex] / float(lengths[segment]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// backrop for mean
|
||||||
|
template <typename T, typename I>
|
||||||
|
int segmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||||
|
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loop_size = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentMeanBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
|
||||||
|
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
// auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentMeanBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths,
|
||||||
|
output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// segmen mean bp main
|
||||||
|
int segmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMeanFunctorBP_, (context, input,
|
||||||
|
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int segmentMeanFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static int unsortedSegmentMeanFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||||
|
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loop_size = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentMeanBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
|
||||||
|
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
// auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentMeanBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths,
|
||||||
|
output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMeanFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMeanFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,423 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/segment.h>
|
||||||
|
#include <ops/declarable/helpers/segment_common.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// Segment ops linear kernels
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template<typename T, typename I>
|
||||||
|
static __global__ void
|
||||||
|
segmentMinLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
|
||||||
|
void *output, Nd4jLong *outputShape) {
|
||||||
|
__shared__
|
||||||
|
T *val;
|
||||||
|
__shared__
|
||||||
|
Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__
|
||||||
|
T *x;
|
||||||
|
__shared__
|
||||||
|
T *z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||||
|
segment = blockIdx.x / threadsPerSegment;
|
||||||
|
x = reinterpret_cast<T *>(input);
|
||||||
|
z = reinterpret_cast<T *>(output);
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
val = reinterpret_cast<T *>(shmem);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
if (segment < numOfClasses) {
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
|
||||||
|
val[segment] = z[zIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template<typename T, typename I>
|
||||||
|
static __global__ void
|
||||||
|
unsortedSegmentMinLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
|
||||||
|
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
|
||||||
|
Nd4jLong *outputShape) {
|
||||||
|
__shared__
|
||||||
|
T *val;
|
||||||
|
__shared__
|
||||||
|
Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__
|
||||||
|
T *x;
|
||||||
|
__shared__
|
||||||
|
T *z;
|
||||||
|
__shared__
|
||||||
|
I *y; //int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
segment = blockIdx.x;
|
||||||
|
x = reinterpret_cast<T *>(input);
|
||||||
|
z = reinterpret_cast<T *>(output);
|
||||||
|
y = reinterpret_cast<I *>(indices);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
|
||||||
|
else
|
||||||
|
z[zIndex] = DataTypeUtils::max<T>();
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
if (y[yIndex] == segment) {
|
||||||
|
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// SegmentMin kernel
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMinTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||||
|
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||||
|
len = shape::length(inputTads);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
total = shape::sizeAt(inputShape, 0);
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto idx = blockIdx.x;
|
||||||
|
if (blockIdx.x <= total) {
|
||||||
|
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||||
|
if (blockIdx.x == start) {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
z[zIndex] = x[xIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
nd4j::math::atomics::nd4j_atomicMin(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// segmen min
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void segmentMinFunctor_(LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
|
||||||
|
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
if (input->isVector()) {
|
||||||
|
segmentMinLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
segmentMinTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, &classesRangesBegs, &classesRangesLens});
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void segmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentMinFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void segmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void unsortedSegmentMinFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
|
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||||
|
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||||
|
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices});
|
||||||
|
if (input->isVector()) {
|
||||||
|
unsortedSegmentMinLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
output->assign(DataTypeUtils::max<T>());
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
dims.x = input->sizeAt(0);
|
||||||
|
segmentMinTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices});
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void unsortedSegmentMinFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentMinFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
|
NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMinBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||||
|
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||||
|
void* outputBuf, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradIn;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, gradLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (auto e = start; e < xLen; e += step) {
|
||||||
|
|
||||||
|
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||||
|
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
auto classIndex = y[yOffset];
|
||||||
|
auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen);
|
||||||
|
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||||
|
|
||||||
|
if (nd4j::math::nd4j_abs(gradIn[gradOffsetI] - x[xOffset]) <= T(1.e-6)) {
|
||||||
|
z[zOffset] = gradOut[gradOffsetO];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentMinBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||||
|
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||||
|
void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||||
|
Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets,
|
||||||
|
Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad,
|
||||||
|
Nd4jLong* outOffsets) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradIn;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
yLen = shape::length(indicesShape);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
currentLen = shape::length(outTad);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||||
|
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||||
|
auto segment = y[yIndex];
|
||||||
|
T* current = x + inputOffsets[i];
|
||||||
|
T* currentOut = z + outOffsets[i];
|
||||||
|
T* in = gradIn + gradInOffsets[segment];
|
||||||
|
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||||
|
|
||||||
|
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||||
|
if (nd4j::math::nd4j_abs(in[e] - current[e]) <= T(1.e-6))
|
||||||
|
currentOut[e] = outGrad[e];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
int segmentMinFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
//int numOfClasses = gradOut->sizeAt(0);
|
||||||
|
// if input is a vector: (as if in doc sample)
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||||
|
segmentMinFunctor_<T, I>(context, input, indices, &tempRes);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loop_size = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
|
||||||
|
segmentMinBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||||
|
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentMinBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||||
|
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// segmen min
|
||||||
|
int segmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentMinFunctorBP_, (context, input,
|
||||||
|
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int segmentMinFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static int unsortedSegmentMinFunctorBP_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
//int numOfClasses = gradOut->sizeAt(0);
|
||||||
|
// if input is a vector: (as if in doc sample)
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||||
|
unsortedSegmentMinFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loop_size = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentMinBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||||
|
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentMinBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||||
|
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut, &tempRes});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
int unsortedSegmentMinFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,419 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/segment.h>
|
||||||
|
#include <ops/declarable/helpers/segment_common.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// Segment Prod ops linear kernels
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentProdLinearKernel(void* input, Nd4jLong* inputShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||||
|
segment = blockIdx.x / threadsPerSegment;
|
||||||
|
x = reinterpret_cast<T*>(input);
|
||||||
|
z = reinterpret_cast<T*>(output);
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
val = reinterpret_cast<T*>(shmem);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
if (segment < numOfClasses) {
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
//val[segment] = ;
|
||||||
|
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
|
||||||
|
val[segment] = z[zIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
// auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
// auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
nd4j::math::atomics::nd4j_atomicMul(&val[segment], x[xIndex]);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
z[zIndex] = val[segment];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void unsortedSegmentProdLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ I* y; //int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||||
|
segment = blockIdx.x;// / threadsPerSegment;
|
||||||
|
x = reinterpret_cast<T*>(input);
|
||||||
|
z = reinterpret_cast<T*>(output);
|
||||||
|
y = reinterpret_cast<I*>(indices);
|
||||||
|
// extern __shared__ unsigned char shmem[];
|
||||||
|
// val = reinterpret_cast<T*>(shmem);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
// if (segment < numOfClasses) {
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
//start = starts[segment];
|
||||||
|
//finish = start + lengths[segment];
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
|
||||||
|
else
|
||||||
|
z[zIndex] = 0; //DataTypeUtils::max<T>();
|
||||||
|
// val[segment] = z[zIndex];
|
||||||
|
// }
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
if (y[yIndex] == segment && e != starts[segment]) {
|
||||||
|
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// SegmentProd kernel
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentProdTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||||
|
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||||
|
len = shape::length(inputTads);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
total = shape::sizeAt(inputShape, 0);
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto idx = blockIdx.x;
|
||||||
|
if (blockIdx.x <= total) {
|
||||||
|
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||||
|
if (blockIdx.x == start) {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
z[zIndex] = x[xIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
nd4j::math::atomics::nd4j_atomicMul(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void segmentProdFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
|
||||||
|
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||||
|
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
segmentProdLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
segmentProdTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void segmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), segmentProdFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void segmentProdFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void unsortedSegmentProdFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
|
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||||
|
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||||
|
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
unsortedSegmentProdLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
output->assign(1);
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
dims.x = input->sizeAt(0);
|
||||||
|
segmentProdTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void unsortedSegmentProdFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentProdFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
|
FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentProdBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||||
|
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||||
|
void* outputBuf, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradIn;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, gradLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (auto e = start; e < xLen; e += step) {
|
||||||
|
|
||||||
|
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||||
|
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
auto classIndex = y[yOffset];
|
||||||
|
auto gradOffsetI = shape::getIndexOffset(classIndex, forwardShape, gradLen);
|
||||||
|
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||||
|
|
||||||
|
z[zOffset] = gradOut[gradOffsetO] * gradIn[gradOffsetI] / x[xOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentProdBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* forwardOutput,
|
||||||
|
Nd4jLong* forwardShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||||
|
void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||||
|
Nd4jLong* inputOffsets, Nd4jLong* gradInTad, Nd4jLong* gradInOffsets,
|
||||||
|
Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad,
|
||||||
|
Nd4jLong* outOffsets) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradIn;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
yLen = shape::length(indicesShape);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradIn = reinterpret_cast<T*>(forwardOutput);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
currentLen = shape::length(outTad);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||||
|
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||||
|
auto segment = y[yIndex];
|
||||||
|
T* current = x + inputOffsets[i];
|
||||||
|
T* currentOut = z + outOffsets[i];
|
||||||
|
T* in = gradIn + gradInOffsets[segment];
|
||||||
|
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||||
|
|
||||||
|
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||||
|
currentOut[e] = outGrad[e] * in[e] / current[e];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
int segmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||||
|
segmentProdFunctor_<T, I>(context, input, indices, &tempRes);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loopSize = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentProdBPLinearKernel<T,I><<<gradOut->lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||||
|
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentProdBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||||
|
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
int segmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentProdFunctorBP_, (context, input,
|
||||||
|
indices, gradOut, output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int segmentProdFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static int unsortedSegmentProdFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray tempRes(gradOut->ordering(), gradOut->getShapeAsVector(), DataTypeUtils::fromT<T>(), context);//->shapeInfo(), context);
|
||||||
|
unsortedSegmentProdFunctor_<T, I>(context, input, indices, numOfClasses, &tempRes);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loopSize = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentProdBPLinearKernel<T,I><<<gradOut->lengthOf(), loopSize, 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradInTads = packGradIn.specialShapeInfo();
|
||||||
|
Nd4jLong* gradInTadOffsets = packGradIn.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentProdBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
tempRes.specialBuffer(), tempRes.specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||||
|
inputTads, inputTadOffsets, gradInTads, gradInTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
int unsortedSegmentProdFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentProdFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentProdFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,280 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/segment.h>
|
||||||
|
#include <ops/declarable/helpers/segment_common.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void unsortedSegmentSqrtNLinearKernel(void* input, Nd4jLong* inputShape, void* indices, Nd4jLong* indicesShape, int* starts, int* lengths, Nd4jLong numOfClasses, void* output, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ I* y; //int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
// threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||||
|
segment = blockIdx.x;// / threadsPerSegment;
|
||||||
|
x = reinterpret_cast<T*>(input);
|
||||||
|
z = reinterpret_cast<T*>(output);
|
||||||
|
y = reinterpret_cast<I*>(indices);
|
||||||
|
// extern __shared__ unsigned char shmem[];
|
||||||
|
// val = reinterpret_cast<T*>(shmem);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
// if (segment < numOfClasses) {
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
//start = starts[segment];
|
||||||
|
//finish = start + lengths[segment];
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]);
|
||||||
|
else
|
||||||
|
z[zIndex] = 0; //DataTypeUtils::max<T>();
|
||||||
|
// val[segment] = z[zIndex];
|
||||||
|
// }
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
for (auto e = threadIdx.x + 1; e < xLen; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
if (y[yIndex] == segment && e != starts[segment]) {
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// SegmentSqrtN kernel
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentSqrtNTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||||
|
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||||
|
len = shape::length(inputTads);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
total = shape::sizeAt(inputShape, 0);
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto idx = blockIdx.x;
|
||||||
|
if (blockIdx.x <= total) {
|
||||||
|
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||||
|
if (blockIdx.x == start) {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
z[zIndex] = x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex] / nd4j::math::nd4j_sqrt<int, T>(lengths[segment]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void unsortedSegmentSqrtNFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
|
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(numOfClasses, indices->lengthOf(), numOfClasses * 32 + 32);
|
||||||
|
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||||
|
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
unsortedSegmentSqrtNLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
output->assign(0);
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
dims.x = input->sizeAt(0);
|
||||||
|
segmentSqrtNTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSqrtNFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
|
FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSqrtNFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentSqrtNBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape, void* indicesBuf, Nd4jLong* indicesShape,
|
||||||
|
int* lengths, void* outputBuf, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradIn;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, gradLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (auto e = start; e < xLen; e += step) {
|
||||||
|
|
||||||
|
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||||
|
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
auto classIndex = y[yOffset];
|
||||||
|
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||||
|
|
||||||
|
z[zOffset] = T(gradOut[gradOffsetO] / math::nd4j_sqrt<int, float>(lengths[classIndex]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentSqrtNBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
|
||||||
|
void* indicesBuf, Nd4jLong* indicesShape, int* lengths, void* outputBuf, Nd4jLong* outputShape,Nd4jLong* inputTad,
|
||||||
|
Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
yLen = shape::length(indicesShape);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
currentLen = shape::length(outTad);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||||
|
// auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||||
|
auto segment = y[i]; //yIndex];
|
||||||
|
T* currentOut = z + outOffsets[i];
|
||||||
|
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||||
|
|
||||||
|
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outTad, currentLen);
|
||||||
|
auto gradIndex = shape::getIndexOffset(e, gradOutTad, gradLen);
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
currentOut[zIndex] = T(outGrad[gradIndex] / math::nd4j_sqrt<int, float>(lengths[segment]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static int unsortedSegmentSqrtNFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
auto numClasses = indices->e<int>(indices->lengthOf() - 1) + 1;
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||||
|
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loop_size = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentSqrtNBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
|
||||||
|
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), lengths, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
// auto packGradIn = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempRes.getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentSqrtNBPTadKernel<T,I><<<indices->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
gradOut->specialBuffer(), gradOut->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), lengths,
|
||||||
|
output->specialBuffer(), output->specialShapeInfo(), inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSqrtNFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSqrtNFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), FLOAT_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,393 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author GS <sgazeos@gmail.com>
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/segment.h>
|
||||||
|
#include <ops/declarable/helpers/segment_common.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include <helpers/ShapeUtils.h>
|
||||||
|
#include <helpers/TAD.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// Segment ops linear kernels
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template<typename T, typename I>
|
||||||
|
static __global__ void
|
||||||
|
segmentSumLinearKernel(void *input, Nd4jLong *inputShape, int *starts, int *lengths, Nd4jLong numOfClasses,
|
||||||
|
void *output, Nd4jLong *outputShape) {
|
||||||
|
__shared__
|
||||||
|
T *val;
|
||||||
|
__shared__
|
||||||
|
Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__
|
||||||
|
T *x;
|
||||||
|
__shared__
|
||||||
|
T *z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
threadsPerSegment = (gridDim.x + numOfClasses - 1) / numOfClasses;
|
||||||
|
segment = blockIdx.x / threadsPerSegment;
|
||||||
|
x = reinterpret_cast<T *>(input);
|
||||||
|
z = reinterpret_cast<T *>(output);
|
||||||
|
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
|
||||||
|
if (segment < numOfClasses) {
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
//val[segment] = ;
|
||||||
|
z[zIndex] = x[shape::getIndexOffset(start, inputShape, xLen)];
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto e = start + threadIdx.x + 1; e < finish; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template<typename T, typename I>
|
||||||
|
static __global__ void
|
||||||
|
unsortedSegmentSumLinearKernel(void *input, Nd4jLong *inputShape, void *indices, Nd4jLong *indicesShape,
|
||||||
|
int *starts, int *lengths, Nd4jLong numOfClasses, void *output,
|
||||||
|
Nd4jLong *outputShape) {
|
||||||
|
__shared__
|
||||||
|
T *val;
|
||||||
|
__shared__
|
||||||
|
Nd4jLong xLen, zLen, segment, zIndex;
|
||||||
|
__shared__
|
||||||
|
T *x;
|
||||||
|
__shared__
|
||||||
|
T *z;
|
||||||
|
__shared__
|
||||||
|
I *y; //int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
segment = blockIdx.x;
|
||||||
|
x = reinterpret_cast<T *>(input);
|
||||||
|
z = reinterpret_cast<T *>(output);
|
||||||
|
y = reinterpret_cast<I *>(indices);
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
zLen = shape::length(outputShape);
|
||||||
|
|
||||||
|
zIndex = shape::getIndexOffset(segment, outputShape, zLen);
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
z[zIndex] = x[shape::getIndexOffset(starts[segment], inputShape, xLen)];
|
||||||
|
else
|
||||||
|
z[zIndex] = 0; //DataTypeUtils::max<T>();
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (lengths[segment] > 0)
|
||||||
|
for (auto e = threadIdx.x; e < xLen; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yIndex = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
if (y[yIndex] == segment && e != starts[segment]) {
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// SegmentSum kernel
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentSumTadKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong* inputTads, Nd4jLong* inputTadOffsets, I* indices, int* starts, int* lengths, Nd4jLong numOfClasses, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* outputTads, Nd4jLong* outputTadOffsets) {
|
||||||
|
__shared__ T* val;
|
||||||
|
__shared__ Nd4jLong len, segment, zIndex, total;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ int threadsPerSegment, start, finish;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
segment = indices[blockIdx.x]; // / threadsPerSegment;
|
||||||
|
z = reinterpret_cast<T*>(outputBuf) + outputTadOffsets[segment];
|
||||||
|
len = shape::length(inputTads);
|
||||||
|
start = starts[segment];
|
||||||
|
finish = start + lengths[segment];
|
||||||
|
total = shape::sizeAt(inputShape, 0);
|
||||||
|
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto idx = blockIdx.x;
|
||||||
|
if (blockIdx.x <= total) {
|
||||||
|
auto x = reinterpret_cast<T *>(inputBuf) + inputTadOffsets[idx];
|
||||||
|
if (blockIdx.x == start) {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
z[zIndex] = x[xIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (auto e = threadIdx.x; e < len; e += blockDim.x) {
|
||||||
|
auto xIndex = shape::getIndexOffset(e, inputTads, len);
|
||||||
|
auto zIndex = shape::getIndexOffset(e, outputTads, len);
|
||||||
|
if (lengths[segment])
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd(&z[zIndex], x[xIndex]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void segmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
Nd4jLong numClasses = indices->e<Nd4jLong>(indices->lengthOf() - 1) + 1;
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numClasses});
|
||||||
|
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
|
||||||
|
dim3 dims(numClasses, indices->lengthOf(), numClasses * 32 + 32);
|
||||||
|
fillUpSegments(indices, numClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
segmentSumLinearKernel<T,I><<<numClasses, input->lengthOf(), numClasses * 32 + 32, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
segmentSumTadKernel<T,I><<<input->sizeAt(0), 512, 2048, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void segmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), segmentSumFunctor_, (context, input, indices, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void segmentSumFunctor_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static void unsortedSegmentSumFunctor_(nd4j::LaunchContext* context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
// NDArray classes = NDArrayFactory::create<int>('c', {numOfClasses, 2});
|
||||||
|
NDArray classesRangesBegs = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
NDArray classesRangesLens = NDArrayFactory::create<int>('c', {numOfClasses});
|
||||||
|
// NDArray row = NDArrayFactory::create<int>('c', {1, 2}, {(int)indices->lengthOf(), (int)0});
|
||||||
|
// classes.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Assign(), &row, &classes);
|
||||||
|
classesRangesBegs.assign(indices->lengthOf());
|
||||||
|
classesRangesLens.assign(0);
|
||||||
|
dim3 dims(numOfClasses, indices->lengthOf(), (numOfClasses + 1) * 64);
|
||||||
|
// int* classesBuf = reinterpret_cast<int*>(classes.specialBuffer());
|
||||||
|
fillUpSegments(indices, numOfClasses, classesRangesBegs, classesRangesLens);
|
||||||
|
int* begins = reinterpret_cast<int*>(classesRangesBegs.specialBuffer());
|
||||||
|
int* lengths = reinterpret_cast<int*>(classesRangesLens.specialBuffer());
|
||||||
|
|
||||||
|
if (input->isVector()) {
|
||||||
|
unsortedSegmentSumLinearKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
output->assign(0);
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
dims.x = input->sizeAt(0);
|
||||||
|
segmentSumTadKernel<T,I><<<dims.x, dims.y, dims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), inputTads, inputTadOffsets, reinterpret_cast<I*>(indices->specialBuffer()), begins, lengths, numOfClasses, output->specialBuffer(), output->specialShapeInfo(), outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
void unsortedSegmentSumFunctor(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), unsortedSegmentSumFunctor_, (context, input, indices, numOfClasses, output),
|
||||||
|
NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void unsortedSegmentSumFunctor_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// Backpropagate ops
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
// Sorted sum backpropagate
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentSumBPLinearKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
|
||||||
|
void* indicesBuf, Nd4jLong* indicesShape, void* outputBuf, Nd4jLong* outputShape) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradIn;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, gradLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
for (auto e = start; e < xLen; e += step) {
|
||||||
|
|
||||||
|
auto zOffset = shape::getIndexOffset(e, outputShape, xLen);
|
||||||
|
auto xOffset = shape::getIndexOffset(e, inputShape, xLen);
|
||||||
|
auto yOffset = shape::getIndexOffset(e, indicesShape, xLen);
|
||||||
|
auto classIndex = y[yOffset];
|
||||||
|
auto gradOffsetO = shape::getIndexOffset(classIndex, epsShape, gradLen);
|
||||||
|
|
||||||
|
z[zOffset] = gradOut[gradOffsetO];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
static __global__ void segmentSumBPTadKernel(void* inputBuf, Nd4jLong* inputShape, void* eps, Nd4jLong* epsShape,
|
||||||
|
void* indicesBuf, Nd4jLong* indicesShape, void* outputBuf, Nd4jLong* outputShape, Nd4jLong* inputTad,
|
||||||
|
Nd4jLong* inputOffsets, Nd4jLong* gradOutTad, Nd4jLong* gradOutOffsets, Nd4jLong* outTad, Nd4jLong* outOffsets) {
|
||||||
|
__shared__ T* x;
|
||||||
|
__shared__ T* gradOut;
|
||||||
|
__shared__ I* y;
|
||||||
|
__shared__ T* z;
|
||||||
|
__shared__ Nd4jLong xLen, yLen, gradLen, currentLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xLen = shape::length(inputShape);
|
||||||
|
x = reinterpret_cast<T*>(inputBuf);
|
||||||
|
y = reinterpret_cast<I*>(indicesBuf);
|
||||||
|
z = reinterpret_cast<T*>(outputBuf);
|
||||||
|
yLen = shape::length(indicesShape);
|
||||||
|
gradOut = reinterpret_cast<T*>(eps);
|
||||||
|
gradLen = shape::length(epsShape);
|
||||||
|
currentLen = shape::length(outTad);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto i = blockIdx.x; i < yLen; i += gridDim.x) {
|
||||||
|
auto yIndex = shape::getIndexOffset(i, indicesShape, yLen);
|
||||||
|
auto segment = y[yIndex];
|
||||||
|
T* currentOut = z + outOffsets[i];
|
||||||
|
T* outGrad = gradOut + gradOutOffsets[segment];
|
||||||
|
|
||||||
|
for (auto e = threadIdx.x; e < currentLen; e += blockDim.x) {
|
||||||
|
currentOut[e] = outGrad[e];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
template <typename T, typename I>
|
||||||
|
int segmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loop_size = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentSumBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
|
||||||
|
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentSumBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||||
|
inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
int segmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return segmentSumFunctorBP_, (context, input,
|
||||||
|
indices, gradOut, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int segmentSumFunctorBP_, (nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
|
||||||
|
template <typename T, typename I>
|
||||||
|
static int unsortedSegmentSumFunctorBP_(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
if (input->isVector()) {
|
||||||
|
Nd4jLong loop_size = input->lengthOf();
|
||||||
|
auto numOfClasses = gradOut->lengthOf(); //indices->e<Nd4jLong>(loop_size - 1);
|
||||||
|
segmentSumBPLinearKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(),
|
||||||
|
input->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
auto packGradOut = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(gradOut->getShapeInfo(), dimensions);
|
||||||
|
Nd4jLong* inputTads = packX.specialShapeInfo();
|
||||||
|
Nd4jLong* inputTadOffsets = packX.specialOffsets();
|
||||||
|
Nd4jLong* outputTads = packZ.specialShapeInfo();
|
||||||
|
Nd4jLong* outputTadOffsets = packZ.specialOffsets();
|
||||||
|
Nd4jLong* gradOutTads = packGradOut.specialShapeInfo();
|
||||||
|
Nd4jLong* gradOutTadOffsets = packGradOut.specialOffsets();
|
||||||
|
|
||||||
|
segmentSumBPTadKernel<T,I><<<gradOut->lengthOf(), input->lengthOf(), 256, *stream>>>(input->specialBuffer(), input->specialShapeInfo(),
|
||||||
|
gradOut->specialBuffer(), gradOut->specialShapeInfo(),
|
||||||
|
indices->specialBuffer(), indices->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
|
||||||
|
inputTads, inputTadOffsets, gradOutTads, gradOutTadOffsets,
|
||||||
|
outputTads, outputTadOffsets);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input, indices, gradOut});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
int unsortedSegmentSumFunctorBP(nd4j::LaunchContext* context , NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(output->dataType(), indices->dataType(), return unsortedSegmentSumFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
}
|
||||||
|
// -------------------------------------------------------------------------------------------------------------- //
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template int unsortedSegmentSumFunctorBP_, (nd4j::LaunchContext* context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES, INTEGER_TYPES);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -24,16 +24,40 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename I, typename B>
|
||||||
static void sequenceMask_(NDArray* input, NDArray* output, int maxIndex) {
|
static __global__ void sequenceMaskKernel(void* inputBuf, Nd4jLong* inputShape, void* outputBuf, Nd4jLong* outputShape, int maxIndex) {
|
||||||
//
|
|
||||||
|
__shared__ I* input;
|
||||||
|
__shared__ B* output;
|
||||||
|
__shared__ Nd4jLong inputLen, outputLen;
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
input = reinterpret_cast<I*>(inputBuf);
|
||||||
|
output = reinterpret_cast<B*>(outputBuf);
|
||||||
|
inputLen = shape::length(inputShape);
|
||||||
|
outputLen = shape::length(outputShape);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto i = blockIdx.x; i < maxIndex; i += gridDim.x)
|
||||||
|
for(auto k = threadIdx.x; k < inputLen; k += blockDim.x)
|
||||||
|
if (i < input[shape::getIndexOffset(k, inputShape, inputLen)])
|
||||||
|
output[shape::getIndexOffset(k * maxIndex + i, outputShape, outputLen)] = B(true);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename I, typename B>
|
||||||
|
static void sequenceMask_(LaunchContext* context, NDArray* input, NDArray* output, int maxIndex) {
|
||||||
|
dim3 launchDims(maxIndex, input->lengthOf(), 128);
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
sequenceMaskKernel<I, B><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), maxIndex);
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), sequenceMask_, (input, output, maxIndex), LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), LIBND4J_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -456,6 +456,102 @@ void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArr
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ static void scatterUpdateCuda(const int opCode, const int numOfInd,
|
||||||
|
void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets,
|
||||||
|
void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets,
|
||||||
|
const int* indexes) {
|
||||||
|
|
||||||
|
__shared__ T *x, *y;
|
||||||
|
__shared__ Nd4jLong arrLenX, arrLenY;
|
||||||
|
|
||||||
|
for (int e = 0; e < numOfInd; e++ ) {
|
||||||
|
|
||||||
|
const auto xIndex = indexes[e];
|
||||||
|
const bool isOwner = xIndex < gridDim.x ? blockIdx.x == xIndex : blockIdx.x == xIndex % gridDim.x;
|
||||||
|
|
||||||
|
if (!isOwner)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
x = reinterpret_cast<T*>(vx) + xOffsets[xIndex];
|
||||||
|
y = reinterpret_cast<T*>(vy) + yOffsets[e];
|
||||||
|
arrLenX = shape::length(xShapeInfo);
|
||||||
|
arrLenY = shape::length(yShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (arrLenX != arrLenY)
|
||||||
|
return;
|
||||||
|
|
||||||
|
for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) {
|
||||||
|
|
||||||
|
const auto xOffset = shape::getIndexOffset(i, xShapeInfo, arrLenX);
|
||||||
|
const auto yOffset = shape::getIndexOffset(i, yShapeInfo, arrLenY);
|
||||||
|
|
||||||
|
switch (opCode) {
|
||||||
|
case 0:
|
||||||
|
x[xOffset] += y[yOffset];
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
x[xOffset] -= y[yOffset];
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
x[xOffset] *= y[yOffset];
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
x[xOffset] /= y[yOffset];
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
x[xOffset] = y[yOffset] - x[xOffset];
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
x[xOffset] = y[yOffset] / x[xOffset];
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
x[xOffset] = y[yOffset];
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__host__ static void scatterUpdateCudaLauncher(const cudaStream_t* stream, const int opCode, const int numOfInd, void* vx, const Nd4jLong *xShapeInfo, const Nd4jLong *xOffsets, void* vy, const Nd4jLong *yShapeInfo, const Nd4jLong *yOffsets, const int* indexes) {
|
||||||
|
|
||||||
|
scatterUpdateCuda<T><<<512, 256, MAX_NUM_THREADS, *stream>>>(opCode, numOfInd, vx, xShapeInfo, xOffsets, vy, yShapeInfo, yOffsets, indexes);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void scatterUpdate(nd4j::LaunchContext* context, NDArray& input, NDArray& updates, const std::vector<int>* intArgs) {
|
||||||
|
|
||||||
|
const int opCode = (*intArgs)[0];
|
||||||
|
const int numOfDims = (*intArgs)[1];
|
||||||
|
const int numOfInd = (*intArgs)[2 + numOfDims];
|
||||||
|
|
||||||
|
std::vector<int> tadDimensions(numOfDims);
|
||||||
|
for (int e = 2; e < 2 + numOfDims; e++)
|
||||||
|
tadDimensions[e-2] = (*intArgs)[e];
|
||||||
|
|
||||||
|
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), tadDimensions);
|
||||||
|
auto packY = ConstantTadHelper::getInstance()->tadForDimensions(updates.getShapeInfo(), tadDimensions);
|
||||||
|
|
||||||
|
NDArray indices(const_cast<int*>(intArgs->data()) + numOfDims + 3, 'c', {numOfInd}, nd4j::DataType::INT32, context);
|
||||||
|
|
||||||
|
PointersManager manager(context, "scatterUpdate");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&input}, {&input, &updates, &indices});
|
||||||
|
BUILD_SINGLE_SELECTOR(input.dataType(), scatterUpdateCudaLauncher, (context->getCudaStream(), opCode, numOfInd, input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), updates.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), reinterpret_cast<int*>(indices.getSpecialBuffer())), LIBND4J_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&input}, {&input, &updates, &indices});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -466,17 +562,140 @@ void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArr
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void swapShuffleKernel(T* input, Nd4jLong* shape, Nd4jLong firstDim, Nd4jLong len, nd4j::graph::RandomGenerator* rng) {
|
||||||
|
auto tid = blockIdx.x * blockDim.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
|
||||||
|
int r = rng->relativeInt(i) % i;
|
||||||
|
if (i != r) {
|
||||||
|
T e0 = input[shape::getIndexOffset(i, shape, len)];
|
||||||
|
T e1 = input[shape::getIndexOffset(r, shape, len)];
|
||||||
|
//math::nd4j_swap<T>(input(i), input(r));
|
||||||
|
input[shape::getIndexOffset(i, shape, len)] = e1;
|
||||||
|
input[shape::getIndexOffset(r, shape, len)] = e0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void fillShuffleKernel(T* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape, Nd4jLong firstDim, Nd4jLong len, int* indices, nd4j::graph::RandomGenerator* rng) {
|
||||||
|
|
||||||
|
// PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||||
|
auto tid = blockIdx.x * blockDim.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
|
||||||
|
int r = rng->relativeInt(i) % i;
|
||||||
|
output[shape::getIndexOffset(i, outputShape, len)] = input[shape::getIndexOffset(indices[r], inputShape, len)];
|
||||||
|
if(i != r) {
|
||||||
|
output[shape::getIndexOffset(r, outputShape, len)] = input[shape::getIndexOffset(indices[i], inputShape, len)];
|
||||||
|
// output.p(r, input.e<T>(indices[i]));
|
||||||
|
// math::nd4j_swap<int>(indices[i], indices[r]);
|
||||||
|
atomicExch(&indices[i], indices[r]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void randomShuffle_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) {
|
void randomShuffle_(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
|
||||||
|
|
||||||
|
// check edge cases first
|
||||||
|
int temp;
|
||||||
|
const int firstDim = input.sizeAt(0);
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input});
|
||||||
|
if(input.lengthOf() == 1 || firstDim == 1) {
|
||||||
|
if(!isInplace)
|
||||||
|
output.assign(input);
|
||||||
|
}
|
||||||
|
else if (input.isVector() || shape::isLikeVector(input.getShapeInfo(), temp)) {
|
||||||
|
|
||||||
|
// apply Fisher-Yates shuffle
|
||||||
|
nd4j::graph::RandomGenerator* dRandom = nullptr;
|
||||||
|
cudaMalloc(&dRandom, sizeof(nd4j::graph::RandomGenerator));
|
||||||
|
cudaMemcpy(dRandom, &rng, sizeof(nd4j::graph::RandomGenerator), cudaMemcpyHostToDevice);
|
||||||
|
T* inputBuf = reinterpret_cast<T*>(input.specialBuffer());
|
||||||
|
if(isInplace) {
|
||||||
|
swapShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, input.lengthOf(), dRandom);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
std::vector<int> indices(firstDim);
|
||||||
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice);
|
||||||
|
//output.p<T>(Nd4jLong(0), input.e<T>(0));
|
||||||
|
PointersManager pointersManager(context, "helper::randomShuffle_");
|
||||||
|
int* indicesDev = reinterpret_cast<int*>(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int)));
|
||||||
|
T* outputBuf = reinterpret_cast<T*>(output.specialBuffer());
|
||||||
|
fillShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, input.lengthOf(), indicesDev, dRandom);
|
||||||
|
pointersManager.synchronize();
|
||||||
|
}
|
||||||
|
// rng.rewindH(firstDim - 1);
|
||||||
|
cudaFree(dRandom);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
// evaluate sub-arrays list of input array through all dimensions excluding first one
|
||||||
|
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
|
||||||
|
auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
|
||||||
|
|
||||||
|
// apply Fisher-Yates shuffle
|
||||||
|
if(isInplace) {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for(int i = firstDim - 1; i > 0; --i) {
|
||||||
|
int r = rng.relativeInt(i) % i;
|
||||||
|
|
||||||
|
if(i != r)
|
||||||
|
subArrsListIn->at(i)->swapUnsafe(*subArrsListIn->at(r));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// evaluate sub-arrays list of output array through all dimensions excluding first one
|
||||||
|
auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
|
||||||
|
std::vector<int> indices(firstDim);
|
||||||
|
std::iota(indices.begin(), indices.end(), 0);
|
||||||
|
bool isZeroShuffled = false;
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance()->tadThreshold())
|
||||||
|
for(int i = firstDim - 1; i > 0; --i) {
|
||||||
|
int r = rng.relativeInt(i) % i;
|
||||||
|
subArrsListOut->at(i)->assign(subArrsListIn->at(indices[r]));
|
||||||
|
if(r == 0)
|
||||||
|
isZeroShuffled = true;
|
||||||
|
|
||||||
|
if(i != r) {
|
||||||
|
subArrsListOut->at(r)->assign(subArrsListIn->at(indices[i]));
|
||||||
|
math::nd4j_swap<int>(indices[i], indices[r]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(!isZeroShuffled)
|
||||||
|
subArrsListOut->at(0)->assign(subArrsListIn->at(0));
|
||||||
|
delete subArrsListOut;
|
||||||
|
}
|
||||||
|
rng.rewindH(firstDim-1);
|
||||||
|
delete subArrsListIn;
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace) {
|
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace) {
|
||||||
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
@ -498,11 +717,6 @@ void eye(nd4j::LaunchContext * context, NDArray& output) {
|
||||||
output.setIdentity();
|
output.setIdentity();
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
void scatterUpdate(nd4j::LaunchContext * context, NDArray& operand, NDArray& updates, const std::vector<int>* intArgs) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T, typename Z>
|
template <typename T, typename Z>
|
||||||
static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
||||||
|
|
|
@ -33,27 +33,7 @@ namespace helpers {
|
||||||
|
|
||||||
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h);
|
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h);
|
||||||
|
|
||||||
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
|
void gruCellBP(nd4j::LaunchContext* context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, NDArray* dLdb, NDArray* dLdbc);
|
||||||
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb);
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
FORCEINLINE NDArray sigmoid(const NDArray& arr) {
|
|
||||||
return (const_cast<NDArray&>(arr)).transform(transform::Sigmoid);
|
|
||||||
}
|
|
||||||
|
|
||||||
FORCEINLINE void sigmoidInplace(const NDArray& arr) {
|
|
||||||
(const_cast<NDArray&>(arr)).applyTransform(transform::Sigmoid);
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
FORCEINLINE NDArray tanh(const NDArray& arr) {
|
|
||||||
return (const_cast<NDArray&>(arr)).transform(transform::Tanh);
|
|
||||||
}
|
|
||||||
|
|
||||||
FORCEINLINE void tanhInplace(const NDArray& arr) {
|
|
||||||
(const_cast<NDArray&>(arr)).applyTransform(transform::Tanh);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,37 +27,37 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FORCEINLINE Nd4jLong longBytes(T value);
|
FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value);
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
FORCEINLINE Nd4jLong longBytes(float value) {
|
FORCEINLINE _CUDA_HD Nd4jLong longBytes(float value) {
|
||||||
int intie = *(int *)&value;
|
int intie = *(int *)&value;
|
||||||
return static_cast<Nd4jLong>(intie);
|
return static_cast<Nd4jLong>(intie);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
FORCEINLINE Nd4jLong longBytes(double value) {
|
FORCEINLINE _CUDA_HD Nd4jLong longBytes(double value) {
|
||||||
Nd4jLong longie = *(Nd4jLong *)&value;
|
Nd4jLong longie = *(Nd4jLong *)&value;
|
||||||
return longie;
|
return longie;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
FORCEINLINE Nd4jLong longBytes(float16 value) {
|
FORCEINLINE _CUDA_HD Nd4jLong longBytes(float16 value) {
|
||||||
return longBytes<float>((float) value);
|
return longBytes<float>((float) value);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
FORCEINLINE Nd4jLong longBytes(Nd4jLong value) {
|
FORCEINLINE _CUDA_HD Nd4jLong longBytes(Nd4jLong value) {
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
FORCEINLINE Nd4jLong longBytes(bfloat16 value) {
|
FORCEINLINE _CUDA_HD Nd4jLong longBytes(bfloat16 value) {
|
||||||
return longBytes<float>((float) value);
|
return longBytes<float>((float) value);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FORCEINLINE Nd4jLong longBytes(T value) {
|
FORCEINLINE _CUDA_HD Nd4jLong longBytes(T value) {
|
||||||
return longBytes<Nd4jLong>((Nd4jLong) value);
|
return longBytes<Nd4jLong>((Nd4jLong) value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,41 +30,34 @@ namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
static FORCEINLINE NDArray activation(const NDArray& arr) {
|
void rnnCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* hPrev, NDArray* ht) {
|
||||||
|
|
||||||
return (const_cast<NDArray&>(arr)).transform(transform::Tanh);
|
// xt input [bS x iS]
|
||||||
|
// Wx input-to-hidden weights, [iS x nU]
|
||||||
|
// Wh hidden-to-hidden weights, [nU x nU]
|
||||||
|
// b biases, [2*nU]: {0, nU} are input-to-hidden biases and {nU, 2*nU} are hidden-to-hidden biases
|
||||||
|
// hPrev previous cell output [bS x nU], that is at previous time step t-1, in case of projection=false -> nU=nU!!!
|
||||||
|
|
||||||
|
const int nU = hPrev->sizeAt(1);
|
||||||
|
|
||||||
|
// ht is current cell output [bS x nU], that is at current time step t
|
||||||
|
ht->assign(mmul(*xt, *Wx) + (*b)({{0, nU}}) + mmul(*hPrev, *Wh) + (*b)({{nU, 2*nU}})); // [bS x nU] + [nU] + [bS x nU] + [nU] = [bS x nU]
|
||||||
|
ht->applyTransform(transform::Tanh);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
void rnnCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* ht_1, NDArray* ht) {
|
|
||||||
|
|
||||||
// xt input [bS x inSize]
|
|
||||||
// Wx input-to-hidden weights, [inSize x numUnits]
|
|
||||||
// Wh hidden-to-hidden weights, [numUnits x numUnits]
|
|
||||||
// b biases, [2*numUnits]: {0, numUnits} are input-to-hidden biases and {numUnits, 2*numUnits} are hidden-to-hidden biases
|
|
||||||
// ht_1 previous cell output [bS x numUnits], that is at previous time step t-1, in case of projection=false -> numUnits=numUnits!!!
|
|
||||||
|
|
||||||
const int numUnits = ht_1->sizeAt(1);
|
|
||||||
|
|
||||||
// ht is current cell output [bS x numUnits], that is at current time step t
|
|
||||||
ht->assign(activation(mmul(*xt, *Wx) + (*b)({{0, numUnits}}) + mmul(*ht_1, *Wh) + (*b)({{numUnits, 2*numUnits}}))); // [bS x numUnits] + [numUnits] + [bS x numUnits] + [numUnits] = [bS x numUnits]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal) {
|
void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* h0, const NDArray* maxTimeStep, NDArray* h, NDArray* hFinal) {
|
||||||
|
|
||||||
// x input [time x bS x inSize]
|
// x input [time x bS x iS]
|
||||||
// Wx input-to-hidden weights, [inSize x numUnits]
|
// Wx input-to-hidden weights, [iS x nU]
|
||||||
// Wh hidden-to-hidden weights, [numUnits x numUnits]
|
// Wh hidden-to-hidden weights, [nU x nU]
|
||||||
// b biases for, [2*numUnits]
|
// b biases for, [2*nU]
|
||||||
|
|
||||||
// h0 initial cell output (at time step = 0) [bS x numUnits]
|
// h0 initial cell output (at time step = 0) [bS x nU]
|
||||||
// maxTimeStep vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
|
// maxTimeStep vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
|
||||||
|
|
||||||
const int time = x->sizeAt(0);
|
const int time = x->sizeAt(0);
|
||||||
const int bS = x->sizeAt(1);
|
const int bS = x->sizeAt(1);
|
||||||
|
|
||||||
// at first time step
|
// at first time step
|
||||||
if(h0)
|
if(h0)
|
||||||
|
@ -82,16 +75,16 @@ void rnnTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
|
||||||
|
|
||||||
auto xt = (*x)({t,t+1, e,e+1, 0,0}, true);
|
auto xt = (*x)({t,t+1, e,e+1, 0,0}, true);
|
||||||
auto ht = (*h)({t,t+1, e,e+1, 0,0}, true);
|
auto ht = (*h)({t,t+1, e,e+1, 0,0}, true);
|
||||||
auto ht_1 = (*hFinal)({e,e+1, 0,0}, true); // previous state
|
auto hPrev = (*hFinal)({e,e+1, 0,0}, true); // previous state
|
||||||
|
|
||||||
if(t >= maxStep) {
|
if(t >= maxStep) {
|
||||||
ht = 0.;
|
ht = 0.;
|
||||||
if(maxStep != 0)
|
if(maxStep != 0)
|
||||||
ht_1.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0}));
|
hPrev.assign((*h)({maxStep-1,maxStep, e,e+1, 0,0}));
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
helpers::rnnCell(context, &xt, Wx, Wh, b, &ht_1, &ht);
|
helpers::rnnCell(context, &xt, Wx, Wh, b, &hPrev, &ht);
|
||||||
ht_1.assign(ht);
|
hPrev.assign(ht);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author sgazeos@gmail.com
|
||||||
|
// @brief helpers common fuctions for segment_* ops (segment_max, segment_min, etc.)
|
||||||
|
// @brief helpers common fuctions for unsorted_segment_* ops (unsorted_segment_max, etc.)
|
||||||
|
//
|
||||||
|
#ifndef __SEGMENT_COMMON_HELPERS__
|
||||||
|
#define __SEGMENT_COMMON_HELPERS__
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#include <NDArray.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
void fillUpSegments(NDArray* indices, Nd4jLong numClasses, NDArray& classesRangesBegs, NDArray& classesRangesLens);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
|
@ -23,6 +23,7 @@
|
||||||
|
|
||||||
#include <ops/declarable/helpers/helpers.h>
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
#include <helpers/helper_random.h>
|
#include <helpers/helper_random.h>
|
||||||
|
#include <graph/RandomGenerator.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -32,7 +33,7 @@ namespace helpers {
|
||||||
|
|
||||||
void trace(nd4j::LaunchContext * context, const NDArray& input, NDArray& output);
|
void trace(nd4j::LaunchContext * context, const NDArray& input, NDArray& output);
|
||||||
|
|
||||||
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::random::RandomBuffer& rng, const bool isInplace);
|
void randomShuffle(nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace);
|
||||||
|
|
||||||
// auxiliary function which serves for recursion purpose and is used in pad operation
|
// auxiliary function which serves for recursion purpose and is used in pad operation
|
||||||
// void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue);
|
// void recursiveLoopForPad(const int mode, NDArray& input, const NDArray& paddings, NDArray& output, std::vector<int> dimensions, int dim, int inIdx, int outIdx, NDArray& padValue);
|
||||||
|
|
|
@ -1126,15 +1126,7 @@ inline __device__ bool nd4j_atomicAdd<bool>(bool* address, bool val) {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline __device__ double nd4j_atomicSub<double>(double* address, double val) {
|
inline __device__ double nd4j_atomicSub<double>(double* address, double val) {
|
||||||
unsigned long long int* address_as_ull =
|
return nd4j_atomicAdd<double>(address, -val);
|
||||||
(unsigned long long int *) address;
|
|
||||||
unsigned long long int old = *address_as_ull, assumed;
|
|
||||||
do {
|
|
||||||
assumed = old;
|
|
||||||
old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val -
|
|
||||||
__longlong_as_double(assumed)));
|
|
||||||
} while (assumed != old);
|
|
||||||
return __longlong_as_double(old);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -1152,15 +1144,7 @@ inline __device__ double nd4j_atomicMul<double>(double* address, double val) {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline __device__ double nd4j_atomicDiv<double>(double* address, double val) {
|
inline __device__ double nd4j_atomicDiv<double>(double* address, double val) {
|
||||||
unsigned long long int* address_as_ull =
|
return nd4j_atomicMul<double>(address, 1./val);
|
||||||
(unsigned long long int*) address;
|
|
||||||
unsigned long long int old = *address_as_ull, assumed;
|
|
||||||
do {
|
|
||||||
assumed = old;
|
|
||||||
old = atomicCAS(address_as_ull, assumed,__double_as_longlong(val /
|
|
||||||
__longlong_as_double(assumed)));
|
|
||||||
} while (assumed != old);
|
|
||||||
return __longlong_as_double(old);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -1179,14 +1163,16 @@ inline __device__ int32_t nd4j_atomicAdd<int32_t>(int32_t* address, int32_t val)
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline __device__ float nd4j_atomicSub<float>(float* address, float val) {
|
inline __device__ float nd4j_atomicSub<float>(float* address, float val) {
|
||||||
int* address_as_ull = (int*) address;
|
return nd4j_atomicAdd<float>(address, -val);
|
||||||
int old = *address_as_ull, assumed;
|
}
|
||||||
do {
|
|
||||||
assumed = old;
|
template <>
|
||||||
old = atomicCAS(address_as_ull, assumed, __float_as_int(val -
|
inline __device__ float16 nd4j_atomicSub<float16>(float16* address, float16 val) {
|
||||||
__float_as_int(assumed)));
|
return nd4j_atomicAdd<float16>(address, -val);
|
||||||
} while (assumed != old);
|
}
|
||||||
return __int_as_float(old);
|
template <>
|
||||||
|
inline __device__ bfloat16 nd4j_atomicSub<bfloat16>(bfloat16* address, bfloat16 val) {
|
||||||
|
return nd4j_atomicAdd<bfloat16>(address, -val);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -1415,6 +1401,30 @@ inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val)
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline __device__ float nd4j_atomicDiv<float>(float* address, float val) {
|
inline __device__ float nd4j_atomicDiv<float>(float* address, float val) {
|
||||||
|
int* address_as_ull =
|
||||||
|
(int*)address;
|
||||||
|
int old = *address_as_ull, assumed;
|
||||||
|
do {
|
||||||
|
assumed = old;
|
||||||
|
old = atomicCAS(address_as_ull, assumed, __float_as_int(__int_as_float(assumed) / val ));
|
||||||
|
} while (assumed != old);
|
||||||
|
return __int_as_float(old);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline __device__ float16 nd4j_atomicDiv<float16>(float16* address, float16 val) {
|
||||||
|
int* address_as_ull =
|
||||||
|
(int*)address;
|
||||||
|
int old = *address_as_ull, assumed;
|
||||||
|
do {
|
||||||
|
assumed = old;
|
||||||
|
old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
|
||||||
|
__float_as_int(assumed)));
|
||||||
|
} while (assumed != old);
|
||||||
|
return __int_as_float(old);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
inline __device__ bfloat16 nd4j_atomicDiv<bfloat16>(bfloat16* address, bfloat16 val) {
|
||||||
int* address_as_ull =
|
int* address_as_ull =
|
||||||
(int*)address;
|
(int*)address;
|
||||||
int old = *address_as_ull, assumed;
|
int old = *address_as_ull, assumed;
|
||||||
|
|
|
@ -76,6 +76,9 @@
|
||||||
(nd4j::DataType::FLOAT32, float), \
|
(nd4j::DataType::FLOAT32, float), \
|
||||||
(nd4j::DataType::DOUBLE, double)
|
(nd4j::DataType::DOUBLE, double)
|
||||||
|
|
||||||
|
#define FLOAT_NATIVE \
|
||||||
|
(nd4j::DataType::FLOAT32, float), \
|
||||||
|
(nd4j::DataType::DOUBLE, double)
|
||||||
|
|
||||||
#define FLOAT_TYPES_0 \
|
#define FLOAT_TYPES_0 \
|
||||||
(nd4j::DataType::HALF, float16)
|
(nd4j::DataType::HALF, float16)
|
||||||
|
|
|
@ -1870,7 +1870,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
NDArray* result = results->at(0);
|
NDArray* result = results->at(0);
|
||||||
result->printIndexedBuffer("OOOOUUUUTTT");
|
// result->printIndexedBuffer("OOOOUUUUTTT");
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
@ -1881,9 +1881,9 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) {
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
|
TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
|
||||||
|
|
||||||
NDArray boxes = NDArrayFactory::create<float>('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f,
|
NDArray boxes = NDArrayFactory::create<double>('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f,
|
||||||
0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101});
|
0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101});
|
||||||
NDArray scales = NDArrayFactory::create<float>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5
|
NDArray scales = NDArrayFactory::create<double>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5
|
||||||
NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5});
|
NDArray expected = NDArrayFactory::create<int>('c', {3}, {3,0,5});
|
||||||
|
|
||||||
nd4j::ops::non_max_suppression op;
|
nd4j::ops::non_max_suppression op;
|
||||||
|
@ -1892,7 +1892,7 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
NDArray* result = results->at(0);
|
NDArray* result = results->at(0);
|
||||||
result->printBuffer("NonMaxSuppression OUtput2");
|
// result->printBuffer("NonMaxSuppression OUtput2");
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
@ -1970,6 +1970,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
|
TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
|
||||||
|
|
||||||
|
|
|
@ -402,22 +402,219 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) {
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests13, CellContains_test_1) {
|
TEST_F(DeclarableOpsTests13, CellContains_test_1) {
|
||||||
|
|
||||||
auto corners = NDArrayFactory::create<double>( {0.5384, 0.5640, 0.3449, 0.5257, 0.5505});
|
auto corners = NDArrayFactory::create<double>( {0.5384, 0.5640, 0.3449, 0.5257, 0.5505});
|
||||||
auto width = NDArrayFactory::create<double>({0.4306, 0.3960, 0.4639, 0.5040, 0.4904});
|
auto width = NDArrayFactory::create<double>({0.4306, 0.3960, 0.4639, 0.5040, 0.4904});
|
||||||
auto point = NDArrayFactory::create<double>({0.3000, 0.2625, 0.2674, 0.8604, 0.4803});
|
auto point = NDArrayFactory::create<double>({0.3000, 0.2625, 0.2674, 0.8604, 0.4803});
|
||||||
//auto exp = NDArrayFactory::create<double>('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000});
|
//auto exp = NDArrayFactory::create<double>('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000});
|
||||||
// data.linspace(1);
|
// data.linspace(1);
|
||||||
|
|
||||||
// auto y = NDArrayFactory::create<double>('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6});
|
// auto y = NDArrayFactory::create<double>('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6});
|
||||||
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
|
// auto eps = NDArrayFactory::create<double>('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6});
|
||||||
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
|
// auto exp = NDArrayFactory::create<double>('c', {2,3}, {1, 2, 1, 2, 2, 2});
|
||||||
nd4j::ops::cell_contains op;
|
nd4j::ops::cell_contains op;
|
||||||
auto result = op.execute({&corners, &width, &point}, {}, {5});
|
auto result = op.execute({&corners, &width, &point}, {}, {5});
|
||||||
ASSERT_EQ(result->status(), Status::OK());
|
ASSERT_EQ(result->status(), Status::OK());
|
||||||
ASSERT_TRUE(result->at(0)->e<bool>(0));
|
ASSERT_TRUE(result->at(0)->e<bool>(0));
|
||||||
//result->at(2)->printBuffer("Symmetrized3");
|
//result->at(2)->printBuffer("Symmetrized3");
|
||||||
//exp.printBuffer("EXPect symm3");
|
//exp.printBuffer("EXPect symm3");
|
||||||
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
// ASSERT_TRUE(exp[i]->equalsTo(result->at(i)));
|
||||||
//ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
//ASSERT_TRUE(exp.equalsTo(result->at(0)));
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustHue_1) {
|
||||||
|
|
||||||
|
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_hue op;
|
||||||
|
auto results = op.execute({&input}, {0.5}, {2});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
// result->printIndexedBuffer();
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustHue_2) {
|
||||||
|
|
||||||
|
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {2,2,3}, {4,100,0, 146,220,5, 97,123.8,230, 255,2,164.8}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_hue op;
|
||||||
|
auto results = op.execute({&input}, {0.9}, {2});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustHue_3) {
|
||||||
|
|
||||||
|
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_hue op;
|
||||||
|
auto results = op.execute({&input}, {-0.9}, {2});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustHue_4) {
|
||||||
|
|
||||||
|
NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_hue op;
|
||||||
|
auto results = op.execute({&input}, {0.5}, {1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustHue_5) {
|
||||||
|
|
||||||
|
NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_hue op;
|
||||||
|
auto results = op.execute({&input}, {0.5}, {0});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
|
||||||
|
|
||||||
|
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_saturation op;
|
||||||
|
auto results = op.execute({&input}, {0.5}, {2});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
// result->printIndexedBuffer();
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustSaturation_2) {
|
||||||
|
|
||||||
|
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_saturation op;
|
||||||
|
auto results = op.execute({&input}, {10}, {2});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
// result->printIndexedBuffer();
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustSaturation_3) {
|
||||||
|
|
||||||
|
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_saturation op;
|
||||||
|
auto results = op.execute({&input}, {-10}, {2});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustSaturation_4) {
|
||||||
|
|
||||||
|
NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_saturation op;
|
||||||
|
auto results = op.execute({&input}, {0.5}, {1});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
// result->printIndexedBuffer();
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
|
||||||
|
|
||||||
|
NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
nd4j::ops::adjust_saturation op;
|
||||||
|
auto results = op.execute({&input}, {0.5}, {0});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto result = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -1479,6 +1479,27 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, random_shuffle_test04) {
|
||||||
|
auto input = NDArrayFactory::create<double>('c', {4});
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::random_shuffle op;
|
||||||
|
//NDArray* output;
|
||||||
|
auto results = op.execute({&input}, {}, {}, {}, true, nd4j::DataType::DOUBLE);
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
auto output = &input; //results->at(0);
|
||||||
|
bool haveZeros = false;
|
||||||
|
for(int i = 0; i < output->lengthOf(); ++i)
|
||||||
|
if(output->e<float>(i) == (float)0.)
|
||||||
|
haveZeros = true;
|
||||||
|
|
||||||
|
ASSERT_TRUE(input.isSameShape(output));
|
||||||
|
//ASSERT_TRUE(!input.equalsTo(output));
|
||||||
|
ASSERT_TRUE(!haveZeros);
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
|
TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
|
||||||
|
@ -1486,17 +1507,17 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::random_shuffle op;
|
nd4j::ops::random_shuffle op;
|
||||||
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
//NDArray* output;
|
||||||
|
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
|
|
||||||
bool haveZeros = false;
|
bool haveZeros = false;
|
||||||
for(int i = 0; i < output->lengthOf(); ++i)
|
for(int i = 0; i < output->lengthOf(); ++i)
|
||||||
if(output->e<float>(i) == (float)0.)
|
if(output->e<float>(i) == (float)0.)
|
||||||
haveZeros = true;
|
haveZeros = true;
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
|
||||||
ASSERT_TRUE(input.isSameShape(output));
|
ASSERT_TRUE(input.isSameShape(output));
|
||||||
ASSERT_TRUE(!input.equalsTo(output));
|
//ASSERT_TRUE(!input.equalsTo(output));
|
||||||
ASSERT_TRUE(!haveZeros);
|
ASSERT_TRUE(!haveZeros);
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
|
@ -1601,8 +1601,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
// z->printIndexedBuffer("Output ");
|
z->printIndexedBuffer("Output ");
|
||||||
// exp.printIndexedBuffer("Expected ");
|
exp.printIndexedBuffer("Expected ");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
@ -1610,6 +1610,75 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||||
|
2., 4., 60., 8., 10.,
|
||||||
|
0., 1., 2., 3., 4.,
|
||||||
|
0., 0., 2., 4., 6.,
|
||||||
|
0., 0., 0., 1., 2.,
|
||||||
|
0., 0., 0., 0., 4.
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||||
|
0.5, -2.0, -13.0, 54.0, -6.75,
|
||||||
|
0.0, 1.0, -1.0, 1.0, 0.0,
|
||||||
|
0, 0, 0.5, -2.0, 0.25,
|
||||||
|
0, 0, 0, 1.0, -0.5,
|
||||||
|
0, 0, 0, 0, 0.25
|
||||||
|
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::matrix_inverse op;
|
||||||
|
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printIndexedBuffer("Output ");
|
||||||
|
exp.printIndexedBuffer("Expected ");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||||
|
1., 0., 0., 0., 0.,
|
||||||
|
2., 1., 0., 0., 0.,
|
||||||
|
30., 2., 1., 0., 0.,
|
||||||
|
4., 3., 2., 1., 0.,
|
||||||
|
5., 4., 3., 2., 1.
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
|
||||||
|
1.0, 0.0, 0.0, 0.0, 0.,
|
||||||
|
-2.0, 1.0, 0., 0., 0.,
|
||||||
|
-26.0, -2.0, 1, 0, 0.,
|
||||||
|
54.0, 1.0, -2.0, 1, 0.,
|
||||||
|
-27.0, 0.0, 1.0, -2.0, 1.
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::matrix_inverse op;
|
||||||
|
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printIndexedBuffer("Output ");
|
||||||
|
exp.printIndexedBuffer("Expected ");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
/*
|
/*
|
||||||
|
@ -1658,6 +1727,39 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||||
|
4., 0., 0., 0., 0.,
|
||||||
|
4., 2., 0., 0., 0.,
|
||||||
|
30., 2., 1., 0., 0.,
|
||||||
|
8., 6., 4., 2., 0.,
|
||||||
|
15., 12., 9., 6., 3.,
|
||||||
|
});
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||||
|
0.25, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
-0.50, 0.5, 0.0, 0.0, 0.0,
|
||||||
|
-6.50, -1.0, 1.0, 0.0, 0.0,
|
||||||
|
13.50, 0.5, -2.0, 0.5, 0.0,
|
||||||
|
-6.75, 0.0, 1.0, -1.0, 0.33333333
|
||||||
|
});
|
||||||
|
|
||||||
|
nd4j::ops::matrix_inverse op;
|
||||||
|
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printIndexedBuffer("Output ");
|
||||||
|
exp.printIndexedBuffer("Expected ");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
||||||
|
|
||||||
|
@ -1695,7 +1797,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {5, 5}, {
|
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||||
1., 2., 30., 4., 5.,
|
1., 2., 30., 4., 5.,
|
||||||
0., 1., 2., 3., 4.,
|
0., 1., 2., 3., 4.,
|
||||||
0., 0., 1., 2., 3.,
|
0., 0., 1., 2., 3.,
|
||||||
|
@ -1703,7 +1805,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
||||||
0., 0., 0., 0., 1.
|
0., 0., 0., 0., 1.
|
||||||
});
|
});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<double>('c', {5, 5}, {
|
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
|
||||||
1.0, -2.0, -26.0, 54.0, -27.0,
|
1.0, -2.0, -26.0, 54.0, -27.0,
|
||||||
0.0, 1.0, -2.0, 1.0, 0.0,
|
0.0, 1.0, -2.0, 1.0, 0.0,
|
||||||
0.0, 0.0, 1.0, -2.0, 1.0,
|
0.0, 0.0, 1.0, -2.0, 1.0,
|
||||||
|
@ -1712,13 +1814,13 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
||||||
});
|
});
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
//z->printIndexedBuffer("Output ");
|
z->printIndexedBuffer("Output ");
|
||||||
//exp.printIndexedBuffer("Expected ");
|
exp.printIndexedBuffer("Expected ");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
|
@ -763,15 +763,15 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) {
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) {
|
TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) {
|
||||||
auto input = NDArrayFactory::create<double>('c', {4, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f});
|
auto input = NDArrayFactory::create<int>('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {4, 4, 16}, {1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
auto exp = NDArrayFactory::create<bool>('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
|
||||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, 0.f,
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
|
||||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f,
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0,
|
||||||
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.f,1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f });
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 });
|
||||||
|
|
||||||
nd4j::ops::sequence_mask op;
|
nd4j::ops::sequence_mask op;
|
||||||
auto result = op.execute({&input}, {}, {});
|
auto result = op.execute({&input}, {}, {});
|
||||||
|
@ -788,19 +788,19 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) {
|
TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) {
|
||||||
auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, {10., 20., 30., 4., 0., 6., 7., 8.});
|
auto input = NDArrayFactory::create<int>('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2, 30}, { 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
auto exp = NDArrayFactory::create<bool>('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||||
1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
|
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||||
|
|
||||||
nd4j::ops::sequence_mask op;
|
nd4j::ops::sequence_mask op;
|
||||||
auto result = op.execute({&input}, {}, {});
|
auto result = op.execute({&input}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
// z->printIndexedBuffer("Output");
|
// z->printBuffer("Output");
|
||||||
// z->printShapeInfo("Shape");
|
// z->printShapeInfo("Shape");
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
|
@ -2770,9 +2770,8 @@ TEST_F(DeclarableOpsTests9, batchnorm_bp_test3) {
|
||||||
ASSERT_TRUE(isGradCorrect);
|
ASSERT_TRUE(isGradCorrect);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
|
||||||
/*
|
/*
|
||||||
//2019/02/23 AB - GRU backprop tests disabled pending update of GRU backprop op after rewriting forward pass
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
|
TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
|
||||||
|
|
||||||
const int bS = 2;
|
const int bS = 2;
|
||||||
|
@ -2780,160 +2779,58 @@ TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) {
|
||||||
const int nU = 4;
|
const int nU = 4;
|
||||||
|
|
||||||
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
|
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
|
||||||
NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
NDArray hi('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||||
NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE);
|
NDArray W('c', {iS+nU, 2*nU}, nd4j::DataType::DOUBLE);
|
||||||
NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE);
|
NDArray Wc('c', {iS+nU, nU}, nd4j::DataType::DOUBLE);
|
||||||
NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE);
|
NDArray b('c', {2*nU}, nd4j::DataType::DOUBLE);
|
||||||
|
NDArray bc('c', {nU}, nd4j::DataType::DOUBLE);
|
||||||
|
NDArray dLdr('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||||
|
NDArray dLdu('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||||
|
NDArray dLdc('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||||
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
x.linspace(-5, 0.5);
|
||||||
h0 = 1.;
|
hi = 1.;
|
||||||
Wx = 0.003;
|
W = 0.003;
|
||||||
Wh = 0.006;
|
Wc = 0.006;
|
||||||
b = 0.5;
|
b = 0.5;
|
||||||
|
bc = 0.35;
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
|
||||||
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
const OpArgsHolder argsHolderFF({&x, &hi, &W, &Wc, &b, &bc}, {}, {});
|
||||||
|
nd4j::ops::gruCell op;
|
||||||
|
auto results = op.execute(argsHolderFF);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto u = results->at(1); // [bS, nU]
|
||||||
|
auto c = results->at(2); // [bS, nU]
|
||||||
|
auto h = results->at(3); // [bS, nU]
|
||||||
|
|
||||||
|
dLdh = 1.; // SUM loss
|
||||||
|
|
||||||
|
NDArray Wch = Wc({iS,iS+nU, 0,0}); // [nU, nU]
|
||||||
|
NDArray dhdc = 1. - *u;
|
||||||
|
NDArray dhdu = hi - *c;
|
||||||
|
NDArray dcdZc = 1. - *c * *c;
|
||||||
|
dLdc.assign(dLdh * dhdc);
|
||||||
|
dLdu.assign(dLdh * dhdu);
|
||||||
|
dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose()));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
|
||||||
|
|
||||||
|
const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {});
|
||||||
|
|
||||||
nd4j::ops::gruCell opFF;
|
nd4j::ops::gruCell opFF;
|
||||||
nd4j::ops::gruCell_bp opBP;
|
nd4j::ops::gruCell_bp opBP;
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1 , 1, 1}, {0., 1.}, nd4j::GradCheck::LossFunc::SUM, true);
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test2) {
|
|
||||||
|
|
||||||
const int bS = 2;
|
|
||||||
const int iS = 3;
|
|
||||||
const int nU = 4;
|
|
||||||
|
|
||||||
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
|
||||||
h0 = 1.;
|
|
||||||
Wx = 0.003;
|
|
||||||
Wh = 0.006;
|
|
||||||
b = 0.;
|
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
|
||||||
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
|
||||||
|
|
||||||
nd4j::ops::gruCell opFF;
|
|
||||||
nd4j::ops::gruCell_bp opBP;
|
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test3) {
|
|
||||||
|
|
||||||
const int bS = 2;
|
|
||||||
const int iS = 3;
|
|
||||||
const int nU = 4;
|
|
||||||
|
|
||||||
NDArray x('c', {bS, iS}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray h0('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray Wx('c', {iS, 3*nU}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray Wh('c', {nU, 3*nU}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray b('c', {3*nU}, nd4j::DataType::DOUBLE);
|
|
||||||
NDArray dLdh('c', {bS, nU}, nd4j::DataType::DOUBLE);
|
|
||||||
// NDArray<double> dLdWx0('c', {iS, 3*nU});
|
|
||||||
// NDArray<double> dLdWh0('c', {nU, 3*nU});
|
|
||||||
// NDArray<double> dLdb0 ('c', {3*nU});
|
|
||||||
|
|
||||||
x = 1.;
|
|
||||||
h0 = 0.0;
|
|
||||||
Wx = 0.0;
|
|
||||||
Wh = 0.0;
|
|
||||||
b = 0.5;
|
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
|
||||||
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
|
||||||
|
|
||||||
nd4j::ops::gruCell opFF;
|
|
||||||
nd4j::ops::gruCell_bp opBP;
|
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
|
||||||
// TEST_F(DeclarableOpsTests9, gru_bp_test1) {
|
|
||||||
|
|
||||||
// const int time = 5;
|
|
||||||
// const int bS = 2;
|
|
||||||
// const int iS = 3;
|
|
||||||
// const int nU = 4;
|
|
||||||
|
|
||||||
// NDArray<double> x ('c', {time, bS, iS});
|
|
||||||
// NDArray<double> h0 ('c', {bS, nU});
|
|
||||||
// NDArray<double> Wx ('c', {iS, 3*nU});
|
|
||||||
// NDArray<double> Wh ('c', {nU, 3*nU});
|
|
||||||
// NDArray<double> b ('c', {3*nU});
|
|
||||||
// NDArray<double> dLdh ('c', {time, bS, nU});
|
|
||||||
|
|
||||||
// x.linspace(0.5, 0.5);
|
|
||||||
// h0 = 1.;
|
|
||||||
// Wx = 0.003;
|
|
||||||
// Wh = 0.006;
|
|
||||||
// b = 0.5;
|
|
||||||
|
|
||||||
// const OpArgsHolder<double> argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
|
||||||
// const OpArgsHolder<double> argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
|
||||||
|
|
||||||
// nd4j::ops::gru<double> opFF;
|
|
||||||
// nd4j::ops::gru_bp<double> opBP;
|
|
||||||
|
|
||||||
// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
|
||||||
|
|
||||||
// ASSERT_TRUE(isGradCorrect);
|
|
||||||
// }
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests9, gru_cell_bp_test3_1) {
|
|
||||||
|
|
||||||
const int bS = 2;
|
|
||||||
const int iS = 3;
|
|
||||||
const int nU = 4;
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {bS, iS});
|
|
||||||
auto h0 = NDArrayFactory::create<double>('c', {bS, nU});
|
|
||||||
auto Wx = NDArrayFactory::create<double>('c', {iS, 3*nU});
|
|
||||||
auto Wh = NDArrayFactory::create<double>('c', {nU, 3*nU});
|
|
||||||
auto b = NDArrayFactory::create<double>('c', {3*nU});
|
|
||||||
auto dLdh = NDArrayFactory::create<double>('c', {bS, nU});
|
|
||||||
// NDArray<double> dLdWx0('c', {iS, 3*nU});
|
|
||||||
// NDArray<double> dLdWh0('c', {nU, 3*nU});
|
|
||||||
// NDArray<double> dLdb0 ('c', {3*nU});
|
|
||||||
|
|
||||||
x = 1.;
|
|
||||||
h0 = 0.0;
|
|
||||||
Wx = 0.0;
|
|
||||||
Wh = 0.0;
|
|
||||||
b = 0.5;
|
|
||||||
|
|
||||||
const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {});
|
|
||||||
const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {});
|
|
||||||
|
|
||||||
nd4j::ops::gruCell opFF;
|
|
||||||
nd4j::ops::gruCell_bp opBP;
|
|
||||||
|
|
||||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
|
||||||
|
|
||||||
ASSERT_TRUE(isGradCorrect);
|
ASSERT_TRUE(isGradCorrect);
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
|
TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
|
||||||
|
|
||||||
|
|
|
@ -719,6 +719,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
|
TEST_F(ParityOpsTests, Test_Scatter_Add_2) {
|
||||||
|
|
||||||
auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4});
|
auto vec = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4});
|
||||||
NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64);
|
NDArray idc('c', {1, 4}, {0, 1, 2, 3}, nd4j::DataType::INT64);
|
||||||
auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1});
|
auto updates = NDArrayFactory::create<float>('c', {1, 4}, {1, 1, 1, 1});
|
||||||
|
@ -1588,36 +1589,79 @@ TEST_F(ParityOpsTests, scatterND_update_test5) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(ParityOpsTests, scatter_update_1) {
|
TEST_F(ParityOpsTests, scatter_update_1) {
|
||||||
auto matrix = NDArrayFactory::create_<float>('c', {3, 2});
|
|
||||||
auto updates = NDArrayFactory::create_<float>('c', {2, 2});
|
|
||||||
updates->assign(1.0);
|
|
||||||
|
|
||||||
//updates.printBuffer("Updates");
|
NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
|
||||||
|
NDArray updates('c', {2,2}, {10,20,30,40}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
auto variableSpace = new VariableSpace();
|
NDArray exp('c', {2,2}, {30,40,10,20}, nd4j::DataType::INT32);
|
||||||
variableSpace->putVariable(-1, matrix);
|
|
||||||
variableSpace->putVariable(-2, updates);
|
|
||||||
variableSpace->putVariable(1, new Variable(&matrix));
|
|
||||||
|
|
||||||
auto block = new Context(1, variableSpace, false);
|
|
||||||
block->fillInputs({-1, -2});
|
|
||||||
|
|
||||||
std::vector<int>* arguments = block->getIArguments();
|
|
||||||
arguments->push_back(0);
|
|
||||||
arguments->push_back(1);
|
|
||||||
arguments->push_back(1);
|
|
||||||
arguments->push_back(2);
|
|
||||||
arguments->push_back(1);
|
|
||||||
arguments->push_back(2);
|
|
||||||
|
|
||||||
nd4j::ops::scatter_update op;
|
nd4j::ops::scatter_update op;
|
||||||
|
auto results = op.execute({&x, &updates}, {}, {6, 1,1, 2,1,0});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
// x.printBuffer();
|
||||||
|
|
||||||
Nd4jStatus result = op.execute(block);
|
ASSERT_TRUE(exp.isSameShape(x));
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result);
|
ASSERT_TRUE(exp.equalsTo(x));
|
||||||
|
|
||||||
delete block;
|
delete results;
|
||||||
delete variableSpace;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ParityOpsTests, scatter_update_2) {
|
||||||
|
|
||||||
|
NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
|
||||||
|
NDArray updates('c', {2,2}, {10,20,30,40}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
NDArray exp('c', {2,2}, {20,10,40,30}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
nd4j::ops::scatter_update op;
|
||||||
|
auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,1,0});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(x));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(x));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ParityOpsTests, scatter_update_3) {
|
||||||
|
|
||||||
|
NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, nd4j::DataType::INT32);
|
||||||
|
NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
nd4j::ops::scatter_update op;
|
||||||
|
auto results = op.execute({&x, &updates}, {}, {6, 2,1,2, 2,1,0});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(x));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(x));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ParityOpsTests, scatter_update_4) {
|
||||||
|
|
||||||
|
NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, nd4j::DataType::INT32);
|
||||||
|
NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
nd4j::ops::scatter_update op;
|
||||||
|
auto results = op.execute({&x, &updates}, {}, {6, 1,0, 2,3,0});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(x));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(x));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
|
@ -278,8 +278,8 @@ TEST_F(RNGTests, Test_Gaussian_22) {
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {10000, 1000});
|
auto x0 = NDArrayFactory::create<float>('c', {10000, 1000});
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {10000, 1000});
|
auto x1 = NDArrayFactory::create<float>('c', {10000, 1000});
|
||||||
|
|
||||||
RandomLauncher::fillGaussian(_rngA, &x0, 0.0f, 1.0f);
|
RandomLauncher::fillGaussian(nd4j::LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f);
|
||||||
RandomLauncher::fillGaussian(_rngB, &x1, 0.0f, 1.0f);
|
RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f);
|
||||||
|
|
||||||
//x0.printIndexedBuffer("x0");
|
//x0.printIndexedBuffer("x0");
|
||||||
//x1.printIndexedBuffer("x1");
|
//x1.printIndexedBuffer("x1");
|
||||||
|
@ -306,7 +306,7 @@ TEST_F(RNGTests, Test_Gaussian_22) {
|
||||||
TEST_F(RNGTests, Test_Gaussian_3) {
|
TEST_F(RNGTests, Test_Gaussian_3) {
|
||||||
auto x0 = NDArrayFactory::create<double>('c', {10000000});
|
auto x0 = NDArrayFactory::create<double>('c', {10000000});
|
||||||
|
|
||||||
RandomLauncher::fillGaussian(_rngA, &x0, 0.0, 1.0);
|
RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, 1.0);
|
||||||
|
|
||||||
auto mean = x0.meanNumber().e<double>(0);
|
auto mean = x0.meanNumber().e<double>(0);
|
||||||
auto stdev = x0.varianceNumber(nd4j::variance::SummaryStatsStandardDeviation, false).e<double>(0);
|
auto stdev = x0.varianceNumber(nd4j::variance::SummaryStatsStandardDeviation, false).e<double>(0);
|
||||||
|
@ -319,8 +319,8 @@ TEST_F(RNGTests, Test_LogNormal_1) {
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {10, 10});
|
auto x0 = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
|
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
|
||||||
RandomLauncher::fillLogNormal(_rngA, &x0, 1.0f, 2.0f);
|
RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||||
RandomLauncher::fillLogNormal(_rngB, &x1, 1.0f, 2.0f);
|
RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||||
|
|
||||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||||
|
|
||||||
|
@ -333,8 +333,8 @@ TEST_F(RNGTests, Test_Truncated_1) {
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {10, 10});
|
auto x0 = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
|
auto x1 = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
|
||||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||||
|
|
||||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||||
|
|
||||||
|
@ -357,8 +357,8 @@ TEST_F(RNGTests, Test_Truncated_2) {
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||||
|
|
||||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||||
|
|
||||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||||
|
|
||||||
|
@ -383,8 +383,8 @@ TEST_F(RNGTests, Test_Truncated_21) {
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||||
|
|
||||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||||
|
|
||||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||||
|
|
||||||
|
@ -430,8 +430,8 @@ TEST_F(RNGTests, Test_Truncated_22) {
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||||
|
|
||||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 2.0f, 4.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 2.0f, 4.0f);
|
||||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 2.0f, 4.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 2.0f, 4.0f);
|
||||||
|
|
||||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||||
|
|
||||||
|
@ -477,8 +477,8 @@ TEST_F(RNGTests, Test_Truncated_23) {
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
auto x0 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
auto x1 = NDArrayFactory::create<float>('c', {1000, 1000});
|
||||||
|
|
||||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 0.0f, 1.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f);
|
||||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 0.0f, 1.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f);
|
||||||
|
|
||||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||||
|
|
||||||
|
@ -524,8 +524,8 @@ TEST_F(RNGTests, Test_Truncated_3) {
|
||||||
auto x0 = NDArrayFactory::create<float>('c', {10000, 1000});
|
auto x0 = NDArrayFactory::create<float>('c', {10000, 1000});
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {10000, 1000});
|
auto x1 = NDArrayFactory::create<float>('c', {10000, 1000});
|
||||||
|
|
||||||
RandomLauncher::fillTruncatedNormal(_rngA, &x0, 1.0f, 2.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f);
|
||||||
RandomLauncher::fillTruncatedNormal(_rngB, &x1, 1.0f, 2.0f);
|
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f);
|
||||||
|
|
||||||
ASSERT_TRUE(x0.equalsTo(&x1));
|
ASSERT_TRUE(x0.equalsTo(&x1));
|
||||||
|
|
||||||
|
@ -964,7 +964,7 @@ TEST_F(RNGTests, Test_Reproducibility_2) {
|
||||||
TEST_F(RNGTests, Test_Uniform_4) {
|
TEST_F(RNGTests, Test_Uniform_4) {
|
||||||
auto x1 = NDArrayFactory::create<double>('c', {1000000});
|
auto x1 = NDArrayFactory::create<double>('c', {1000000});
|
||||||
|
|
||||||
RandomLauncher::fillUniform(_rngB, &x1, 1.0, 2.0);
|
RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0, 2.0);
|
||||||
|
|
||||||
/* Check up distribution */
|
/* Check up distribution */
|
||||||
auto mean = x1.reduceNumber(reduce::Mean);
|
auto mean = x1.reduceNumber(reduce::Mean);
|
||||||
|
|
|
@ -69,6 +69,24 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_1) {
|
||||||
ASSERT_EQ(ev, v);
|
ASSERT_EQ(ev, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SortCudaTests, test_linear_sort_by_val_2) {
|
||||||
|
auto k = NDArrayFactory::create<int>('c', {6}, {0, 1, 2, 3, 4, 5});
|
||||||
|
// auto v = NDArrayFactory::create<double>('c', {6}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5});
|
||||||
|
NDArray v = NDArrayFactory::create<double>('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
||||||
|
auto ek = NDArrayFactory::create<int>('c', {6}, {3, 0, 1, 2, 4, 5});
|
||||||
|
auto ev = NDArrayFactory::create<double>('c', {6}, {0.95, 0.9, 0.75, 0.6, 0.5, 0.3});
|
||||||
|
|
||||||
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
|
NativeOps nativeOps;
|
||||||
|
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
|
||||||
|
k.tickWriteDevice();
|
||||||
|
v.tickWriteDevice();
|
||||||
|
k.printIndexedBuffer("KEYS");
|
||||||
|
ASSERT_EQ(ek, k);
|
||||||
|
ASSERT_EQ(ev, v);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
|
TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
|
||||||
auto k = NDArrayFactory::create<Nd4jLong>('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8});
|
auto k = NDArrayFactory::create<Nd4jLong>('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8});
|
||||||
auto v = NDArrayFactory::create<double>('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5});
|
auto v = NDArrayFactory::create<double>('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5});
|
||||||
|
|
Loading…
Reference in New Issue