[WIP] more CUDA stuff (#57)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* Added gradcheck test for dynamic_partition_bp op.

* - implementation of dilation op (cpu and cuda)

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed broadcast_dynamic_shape 1D case and tests.

* Fixed usage of default integer arguments.

* Fixed dynamic_partition_bp op and tests.

* Eliminated test with grad check for dynamic_partition_bp op.

* start working on cuda svd - porting available corresponding api from cuSOLVER library

Signed-off-by: Yurii <yurii@skymind.io>

* provide prelu_bp

Signed-off-by: Yurii <yurii@skymind.io>

* - provide gruCell_bp (old version ??)

Signed-off-by: Yurii <yurii@skymind.io>

* - polishing cumsum_bp and cumprod_bp tests

Signed-off-by: Yurii <yurii@skymind.io>

* provide sparseSoftmaxCrossEntropyWithLogits and sparseSoftmaxCrossEntropyWithLogits_grad

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed atomicMul with float input/output

* implementation of cuda kernel for triu_bp operation

Signed-off-by: Yurii <yurii@skymind.io>

* Refactored lup helper to add parrallel computing.

* cusolver libraries

Signed-off-by: raver119 <raver119@gmail.com>

* uncomment cuSolver APIs in svd.cu

Signed-off-by: Yurii <yurii@skymind.io>

* cusolver var

Signed-off-by: raver119 <raver119@gmail.com>

* - further work on cuSolver svd

Signed-off-by: Yurii <yurii@skymind.io>

* Implement usage of cuda solver to LUP decomposition.

* - correct naames in lup functions

Signed-off-by: Yurii <yurii@skymind.io>

* correct svdQR cuda

Signed-off-by: Yurii <yurii@skymind.io>

* - provide transpositions of input matrices in case of c order in svdCudaQR

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed implementation issues with LUP usign cuda solver.

* Implementation of matrix_determinant helper with cuda kernels. Working revision.

* Implemented log_matrix_determinant helper with cuda kernels.

* - implementation of batched cuda svd

Signed-off-by: Yurii <yurii@skymind.io>

* Refactored cholesky helper and implementation of cuda solver cholesky batch.

* - implementation of cuda kernel for tile bp

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of cholesky and logdet with cuda kernels.

* - implementation of cuda kernel for sru_bidirectional

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed cholesky helper.

* Cholesky op helper implementation. Working double-based cublas implementation.

* bad import excluded

Signed-off-by: raver119 <raver119@gmail.com>

* Finished with cuda implementation of cholesky helper and tests.

* - implementation of cuda kernel for sru_bidirectional_backprop operation

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of matrix_inverse op helper with cuda kernels. The first revision.

* - start working on gruCell_bp

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of matrix_inverse helper.

* - further work on new gruCell_bp

Signed-off-by: Yurii <yurii@skymind.io>

* cuBLAS related fixes

Signed-off-by: raver119 <raver119@gmail.com>

* calculateOutputShapes() now passes device buffers as well

Signed-off-by: raver119 <raver119@gmail.com>

* special concat/average/accumulate init host pointers now

Signed-off-by: raver119 <raver119@gmail.com>

* few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* additional CudaDataBufferFactory signatures certain for data types

Signed-off-by: raver119 <raver119@gmail.com>

* cuSolver host buffer

Signed-off-by: raver119 <raver119@gmail.com>

* buffer to buffer memcpy host ptr allocation

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-12 11:51:51 +03:00 committed by AlexDBlack
parent cb6654bebb
commit c969b724bb
75 changed files with 3716 additions and 1508 deletions

View File

@ -288,7 +288,7 @@ if(CUDA_BLAS)
endif() endif()
target_link_libraries(${LIBND4J_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES}) target_link_libraries(${LIBND4J_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY})
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
install(TARGETS ${LIBND4J_NAME} DESTINATION .) install(TARGETS ${LIBND4J_NAME} DESTINATION .)

View File

@ -1143,9 +1143,9 @@ namespace nd4j {
* - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both) * - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
* - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both) * - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both)
* direction - in what direction to fill matrix. There are 3 possible directions: * direction - in what direction to fill matrix. There are 3 possible directions:
* 'u' - fill up, mathematically this corresponds to lower triangular matrix, parameter "lower" is not taken into account * 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected
* 'l' - fill down, mathematically this corresponds to upper triangular matrix, parameter "upper" is not taken into account * 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected
* 'b' - fill in both directions, both parameters "lower" and "upper" are taken into account * 'b' - fill in both directions, both "lower" and "upper" are taken into account
* rest of target elements are equal to this array elements * rest of target elements are equal to this array elements
* target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2) * target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2)
*/ */

View File

@ -2371,7 +2371,7 @@ void NDArray::tileToShape(const std::vector<Nd4jLong>& shape, NDArray* target) {
if(i > rankOf()) if(i > rankOf())
repeats[newRank-i] = shape[newRank - i]; repeats[newRank-i] = shape[newRank - i];
else else
repeats[newRank-i] = shape[newRank - i] / thisShape[rankOf() - i]; repeats[newRank-i] = shape[newRank - i] / thisShape[rankOf() - i];
} }
tilei(repeats); tilei(repeats);

View File

@ -228,7 +228,7 @@ void* NDArray::getSpecialBuffer() const {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// change an array by repeating it the number of times given by reps. // change an array by repeating it the number of times given by reps.
NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const { NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
int dim = reps.size(); const int repsSize = reps.size();
int product = 1; int product = 1;
for(const auto& item : reps) for(const auto& item : reps)
product *= item; product *= item;
@ -236,11 +236,11 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !"); throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
int rankOld = rankOf(); int rankOld = rankOf();
int diff = rankOld - dim; int diff = rankOld - repsSize;
if(product==1) { // in this case 2 possibilities are present: just reshape or nothing to do if(product==1) { // in this case 2 possibilities are present: just reshape or nothing to do
NDArray result(*this); NDArray result(*this);
if(diff < 0) { // reshape to higher dimension if(diff < 0) { // reshape to higher dimension
std::vector<Nd4jLong> shapeNew = reps; // need to have unities at first "diff" positions of new shape std::vector<Nd4jLong> shapeNew = reps; // there is requirement to have unities at first "diff" positions of new shape
memcpy(&shapeNew[-diff], result.getShapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions memcpy(&shapeNew[-diff], result.getShapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions
result.reshapei(ordering(), shapeNew); result.reshapei(ordering(), shapeNew);
} }

View File

@ -1489,14 +1489,8 @@ void NativeOps::specialConcat(
Nd4jPointer *inputShapeInfo, Nd4jPointer *inputShapeInfo,
void *dZ, void *dZ,
Nd4jLong *dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) { Nd4jLong *dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) {
nd4j::SpecialMethods<float>::concatCpuGeneric(
dimension,
numArrays,
data,
inputShapeInfo,
dZ,
dZShapeInfo);
BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), nd4j::SpecialMethods ,::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES);
} }
@ -2578,8 +2572,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
// we shouldn't copy buffer if that's empty array // we shouldn't copy buffer if that's empty array
void *buffer_ = nd4j::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e]; void *buffer_ = nd4j::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e];
void *bufferD_ = nd4j::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e + numInputShapes];
auto array = new nd4j::NDArray(buffer_, shape_); auto array = new nd4j::NDArray(buffer_, bufferD_, shape_);
// block should contain references to proper variable // block should contain references to proper variable
varSpace.putVariable(1, e, array); varSpace.putVariable(1, e, array);

View File

@ -86,6 +86,7 @@ class ND4J_EXPORT LaunchContext {
FORCEINLINE void setCudaStream(cudaStream_t* cudaStream) {_cudaStream = cudaStream;}; FORCEINLINE void setCudaStream(cudaStream_t* cudaStream) {_cudaStream = cudaStream;};
FORCEINLINE void setCudaSpecialStream(cudaStream_t* cudaStream) {_cudaSpecialStream = cudaStream;}; FORCEINLINE void setCudaSpecialStream(cudaStream_t* cudaStream) {_cudaSpecialStream = cudaStream;};
FORCEINLINE void setCublasHandle(void *handle) {_cublasHandle = handle; };
#endif // JCPP #endif // JCPP

View File

@ -454,6 +454,9 @@ namespace nd4j {
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
_context = new LaunchContext(cudaStream, reductionPointer, allocationPointer); _context = new LaunchContext(cudaStream, reductionPointer, allocationPointer);
// FIXME: either pass handle from outside, or make sure outside we use the same handle
_context->setCublasHandle(LaunchContext::defaultContext()->getCublasHandle());
for (auto v: _fastpath_out) for (auto v: _fastpath_out)
v->setContext(_context); v->setContext(_context);

View File

@ -22,7 +22,6 @@
#include <cublas_v2.h> #include <cublas_v2.h>
#include "../MmulHelper.h" #include "../MmulHelper.h"
#include <specials_cuda.h> #include <specials_cuda.h>
#include <helpers/PointersManager.h>
namespace nd4j { namespace nd4j {

View File

@ -552,7 +552,7 @@ std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& max, const NDAr
// evaluate shapeInfo for resulting array from tile operation // evaluate shapeInfo for resulting array from tile operation
Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, nd4j::memory::Workspace* workspace) { Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, nd4j::memory::Workspace* workspace) {
// check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing) // check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing)
int dim = reps.size(); int repsSize = reps.size();
int product = 1; int product = 1;
for(const auto& item : reps) for(const auto& item : reps)
product *= item; product *= item;
@ -560,24 +560,24 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !"); throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
int rankOld = arr.rankOf(); int rankOld = arr.rankOf();
int diff = rankOld - dim; int diff = rankOld - repsSize;
// evaluate new shapeInfo // evaluate new shapeInfo
Nd4jLong* newShapeInfo = nullptr; Nd4jLong* newShapeInfo = nullptr;
if(diff < 0) { if(diff < 0) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(dim), Nd4jLong); ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong);
newShapeInfo[0] = dim; // set new rank newShapeInfo[0] = repsSize; // set new rank
for(int i=1; i <= -diff; ++i) for(int i=1; i <= -diff; ++i)
newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place
memcpy(newShapeInfo + 1 - diff, arr.getShapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place memcpy(newShapeInfo + 1 - diff, arr.getShapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place
for(int i=1; i <= dim; ++i) for(int i=1; i <= repsSize; ++i)
newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps
} }
else { else {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong); ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
memcpy(newShapeInfo, arr.getShapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo memcpy(newShapeInfo, arr.getShapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo
for(int i=1; i <= dim; ++i) for(int i=1; i <= repsSize; ++i)
newShapeInfo[rankOld + 1 - i] *= reps[dim - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps
} }
shape::updateStrides(newShapeInfo, arr.ordering()); shape::updateStrides(newShapeInfo, arr.ordering());
ArrayOptions::setDataType(newShapeInfo, arr.dataType()); ArrayOptions::setDataType(newShapeInfo, arr.dataType());

View File

@ -889,7 +889,9 @@ namespace shape {
* @param indices the indices to iterate over * @param indices the indices to iterate over
* @return the double at the specified index * @return the double at the specified index
*/ */
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices,int rank); ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices, const int rank);
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset = 0);
ND4J_EXPORT Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices);
ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank); ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
@ -987,7 +989,8 @@ namespace shape {
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array // calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr); // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude = nullptr);
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
// rank is equal to size of shape // rank is equal to size of shape
@ -3200,7 +3203,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
* @param indices the indices to iterate over * @param indices the indices to iterate over
* @return the double at the specified index * @return the double at the specified index
*/ */
INLINEDEF _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices, int rank) { INLINEDEF _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices, const int rank) {
Nd4jLong offset = baseOffset; Nd4jLong offset = baseOffset;
for(int i = 0; i < rank; i++) { for(int i = 0; i < rank; i++) {
if(shape[i] != 1) if(shape[i] != 1)
@ -3210,6 +3213,21 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
return offset; return offset;
} }
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset) {
return shape::getOffset(baseOffset, shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)), shape::stride(const_cast<Nd4jLong*>(shapeInfo)), indices, shapeInfo[0]);
}
INLINEDEF Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices) {
Nd4jLong offset = 0;
for(uint i = 0; i < shapeInfo[0]; ++i)
if(shapeInfo[i + 1] != 1)
offset += indices[i] * shapeInfo[shapeInfo[0] + i + 1];
return offset;
}
@ -4226,21 +4244,18 @@ INLINEDEF _CUDA_HD void maxIndToMinInd(Nd4jLong* maxIdxs, Nd4jLong* minIdxs, con
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude) { INLINEDEF _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude) {
const auto rankMin = shape::rank(minShapeInfo); const auto rankMin = shape::rank(minShapeInfo);
const auto rankMax = shape::rank(maxShapeInfo); const auto rankMax = shape::rank(maxShapeInfo);
// if(rankMin >= rankMax) // if(rankMin >= rankMax)
// throw std::runtime_error("shape::subArrayIndex method: rank of min array should be smaller then rank of max array!"); // throw std::runtime_error("shape::subArrayIndex method: rank of min array should be smaller then rank of max array!");
// if(rankMax > MAX_RANK/2)
// throw std::runtime_error("shape::subArrayIndex method: rank of max array should be <= MAX_RANK/2 !");
const auto diff = rankMax - rankMin; // the size of dimsToExclude is equal to diff const auto diff = rankMax - rankMin; // the size of dimsToExclude is equal to diff
Nd4jLong buffer[MAX_RANK]; Nd4jLong* indices = memBuff;
Nd4jLong* indices = buffer; Nd4jLong* increment = memBuff + rankMax;
Nd4jLong* increment = buffer + MAX_RANK/2;
int N, minI, maxI; int N, minI, maxI;

View File

@ -92,7 +92,6 @@ __global__ void execOesTadKernelKey(void *vx, Nd4jLong *xShapeInfo,
} }
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo, __global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,

View File

@ -39,7 +39,6 @@ CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) {
std::vector<int> sharedAxes = *block.getIArguments(); std::vector<int> sharedAxes = *block.getIArguments();
const int inputRank = input->rankOf(); const int inputRank = input->rankOf();
const int alphaRank = alpha->rankOf();
const int numSharedAxes = sharedAxes.size(); // can be zero as well const int numSharedAxes = sharedAxes.size(); // can be zero as well
const Nd4jLong inputLen = input->lengthOf(); const Nd4jLong inputLen = input->lengthOf();
const Nd4jLong alphaLen = alpha->lengthOf(); const Nd4jLong alphaLen = alpha->lengthOf();
@ -91,7 +90,6 @@ CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) {
std::vector<int> sharedAxes = *block.getIArguments(); std::vector<int> sharedAxes = *block.getIArguments();
const int inputRank = input->rankOf(); const int inputRank = input->rankOf();
const int alphaRank = alpha->rankOf();
const int numSharedAxes = sharedAxes.size(); // can be zero as well const int numSharedAxes = sharedAxes.size(); // can be zero as well
const Nd4jLong inputLen = input->lengthOf(); const Nd4jLong inputLen = input->lengthOf();
const Nd4jLong alphaLen = alpha->lengthOf(); const Nd4jLong alphaLen = alpha->lengthOf();

View File

@ -29,18 +29,18 @@ namespace ops {
CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) { CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
const int rank = x->rankOf(); const int rank = x->rankOf();
REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank);
const bool fullUV = (bool)INT_ARG(0); const bool fullUV = (bool)INT_ARG(0);
const bool calcUV = (bool)INT_ARG(1); const bool calcUV = (bool)INT_ARG(1);
const int switchNum = INT_ARG(2); const int switchNum = INT_ARG(2);
#ifndef __CUDABLAS__ // #ifndef __CUDABLAS__
helpers::svd(block.launchContext(), x, {OUTPUT_VARIABLE(0), calcUV ? OUTPUT_VARIABLE(1) : nullptr, calcUV ? OUTPUT_VARIABLE(2) : nullptr}, fullUV, calcUV, switchNum); helpers::svd(block.launchContext(), x, {OUTPUT_VARIABLE(0), calcUV ? OUTPUT_VARIABLE(1) : nullptr, calcUV ? OUTPUT_VARIABLE(2) : nullptr}, fullUV, calcUV, switchNum);
#endif // #endif
return Status::OK();; return Status::OK();;
} }
@ -56,28 +56,28 @@ DECLARE_SHAPE_FN(svd) {
auto inShapeInfo = inputShape->at(0); auto inShapeInfo = inputShape->at(0);
bool fullUV = (bool)INT_ARG(0); bool fullUV = (bool)INT_ARG(0);
bool calcUV = (bool)INT_ARG(1); bool calcUV = (bool)INT_ARG(1);
const int rank = inShapeInfo[0]; const int rank = inShapeInfo[0];
REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank);
const int diagSize = inShapeInfo[rank] < inShapeInfo[rank-1] ? inShapeInfo[rank] : inShapeInfo[rank-1]; const int diagSize = inShapeInfo[rank] < inShapeInfo[rank-1] ? inShapeInfo[rank] : inShapeInfo[rank-1];
Nd4jLong* sShapeInfo(nullptr); Nd4jLong* sShapeInfo(nullptr);
if(rank == 2) { if(rank == 2) {
ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong);
sShapeInfo[0] = 1; sShapeInfo[0] = 1;
sShapeInfo[1] = diagSize; sShapeInfo[1] = diagSize;
} }
else { else {
ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank-1), Nd4jLong); ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank-1), Nd4jLong);
sShapeInfo[0] = rank - 1; sShapeInfo[0] = rank - 1;
for(int i=1; i <= rank-2; ++i) for(int i=1; i <= rank-2; ++i)
sShapeInfo[i] = inShapeInfo[i]; sShapeInfo[i] = inShapeInfo[i];
sShapeInfo[rank-1] = diagSize; sShapeInfo[rank-1] = diagSize;
} }
ShapeUtils::updateStridesAndType(sShapeInfo, inShapeInfo, shape::order(inShapeInfo)); ShapeUtils::updateStridesAndType(sShapeInfo, inShapeInfo, shape::order(inShapeInfo));
if(calcUV){ if(calcUV){
Nd4jLong *uShapeInfo(nullptr), *vShapeInfo(nullptr); Nd4jLong *uShapeInfo(nullptr), *vShapeInfo(nullptr);
@ -93,10 +93,10 @@ DECLARE_SHAPE_FN(svd) {
vShapeInfo[rank-1] = vShapeInfo[rank]; vShapeInfo[rank-1] = vShapeInfo[rank];
vShapeInfo[rank] = diagSize; vShapeInfo[rank] = diagSize;
} }
shape::updateStrides(uShapeInfo, shape::order(inShapeInfo)); shape::updateStrides(uShapeInfo, shape::order(inShapeInfo));
shape::updateStrides(vShapeInfo, shape::order(inShapeInfo)); shape::updateStrides(vShapeInfo, shape::order(inShapeInfo));
auto result = SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(sShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(uShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(vShapeInfo))); auto result = SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(sShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(uShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(vShapeInfo)));
RELEASE(sShapeInfo, block.workspace()); RELEASE(sShapeInfo, block.workspace());
RELEASE(uShapeInfo, block.workspace()); RELEASE(uShapeInfo, block.workspace());

View File

@ -35,8 +35,8 @@ namespace ops {
REQUIRE_TRUE(input->rankOf() == 4, 0, "Dilation2D: input should be 4D"); REQUIRE_TRUE(input->rankOf() == 4, 0, "Dilation2D: input should be 4D");
REQUIRE_TRUE(weights->rankOf() == 3, 0, "Dilation2D: weights should be 3D"); REQUIRE_TRUE(weights->rankOf() == 3, 0, "Dilation2D: weights should be 3D");
const int batch_size = input->sizeAt(0); const int bS = input->sizeAt(0);
const int depth = input->sizeAt(3); const int iC = input->sizeAt(3);
const bool isSameShape = INT_ARG(0) == 1; const bool isSameShape = INT_ARG(0) == 1;
REQUIRE_TRUE(input->sizeAt(3) == weights->sizeAt(2), 0, "Dilation2D: number of input channels doesn't match number of channels in weights: %i vs %i", input->sizeAt(3), weights->sizeAt(2)); REQUIRE_TRUE(input->sizeAt(3) == weights->sizeAt(2), 0, "Dilation2D: number of input channels doesn't match number of channels in weights: %i vs %i", input->sizeAt(3), weights->sizeAt(2));
@ -66,17 +66,17 @@ namespace ops {
} }
int stride_rows = 0, stride_cols = 0; int sH = 0, sW = 0;
int rate_rows = 0, rate_cols = 0; int dH = 0, dW = 0;
int pad_top = 0, pad_left = 0; int pH = 0, pW = 0;
int out_rows = 0, out_cols = 0; int oH = 0, oW = 0;
helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols); helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW);
REQUIRE_TRUE(out_rows > 0 && out_cols > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", out_rows, out_cols); REQUIRE_TRUE(oH > 0 && oW > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", oH, oW);
helpers::dilation2d(block.launchContext(), input, weights, output, stride_rows, stride_cols, rate_rows, rate_cols, pad_top, pad_left); helpers::dilation2d(block.launchContext(), input, weights, output, sH, sW, pH, pW, dH, dW);
return Status::OK(); return Status::OK();
} }
@ -91,8 +91,8 @@ namespace ops {
auto input = inputShape->at(0); auto input = inputShape->at(0);
auto weights = inputShape->at(1); auto weights = inputShape->at(1);
const int batch_size = shape::sizeAt(input, 0); const int bS = shape::sizeAt(input, 0);
const int depth = shape::sizeAt(input, 3); const int iC = shape::sizeAt(input, 3);
const bool isSameShape = INT_ARG(0) == 1; const bool isSameShape = INT_ARG(0) == 1;
std::vector<int> strides(4); std::vector<int> strides(4);
@ -121,14 +121,14 @@ namespace ops {
strides[cnt] = INT_ARG(e++); strides[cnt] = INT_ARG(e++);
} }
int stride_rows = 0, stride_cols = 0; int sH = 0, sW = 0;
int rate_rows = 0, rate_cols = 0; int dH = 0, dW = 0;
int pad_top = 0, pad_left = 0; int pH = 0, pW = 0;
int out_rows = 0, out_cols = 0; int oH = 0, oW = 0;
helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols); helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW);
std::array<Nd4jLong, 4> shape = {{batch_size, out_rows, out_cols, depth}}; std::array<Nd4jLong, 4> shape = {{bS, oH, oW, iC}};
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data()); newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data());
return SHAPELIST(newShape); return SHAPELIST(newShape);
} }

View File

@ -37,32 +37,32 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0)
const int labelsRank = labels->rankOf(); const int labelsRank = labels->rankOf();
const int logitsRank = logits->rankOf(); const int logitsRank = logits->rankOf();
// input validation // input validation
REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank); REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank);
std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector(); // this is correct std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector(); // this is correct
std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector(); std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector();
logitsShape.pop_back(); logitsShape.pop_back();
bool equalSoft = logitsShape == labelsShape; bool equalSoft = logitsShape == labelsShape;
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str()); REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str());
std::vector<int> dimension = {-1}; std::vector<int> dimension = {-1};
auto maxAlongDim = logits->reduceAlongDims(reduce::Max, dimension, true); auto maxAlongDim = logits->reduceAlongDims(reduce::Max, dimension, true);
auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr); auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr);
auto logSoftMax = ( logitsExp / logitsExp.reduceAlongDims(reduce::Sum, dimension, true) ).transform(transform::Log); auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDims(reduce::Sum, dimension, true) ).transform(transform::Log));
helpers::scatterForLoss(block.launchContext(), *labels, -logSoftMax, *output, false); helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, false);
return Status::OK(); return Status::OK();
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits) { DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits) {
getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS}); getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS});
} }
@ -79,8 +79,8 @@ DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits) {
if (labelsShapeInfo[i] != logitsShapeInfo[i]) { if (labelsShapeInfo[i] != logitsShapeInfo[i]) {
equalSoft = false; equalSoft = false;
break; break;
} }
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, logitsShapeInfo, false, block.getWorkspace()); auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, logitsShapeInfo, false, block.getWorkspace());
@ -96,7 +96,7 @@ DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0) { CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0) {
auto labels = INPUT_VARIABLE(0); auto labels = INPUT_VARIABLE(0);
auto logits = INPUT_VARIABLE(1); auto logits = INPUT_VARIABLE(1);
@ -104,15 +104,15 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false,
const int labelsRank = labels->rankOf(); const int labelsRank = labels->rankOf();
const int logitsRank = logits->rankOf(); const int logitsRank = logits->rankOf();
// input validation // input validation
REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank); REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank);
std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector(); // this is correct std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector(); // this is correct
std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector(); std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector();
logitsShape.pop_back(); logitsShape.pop_back();
bool equalSoft = logitsShape == labelsShape; bool equalSoft = logitsShape == labelsShape;
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str()); REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str());
std::vector<int> dimension = {-1}; std::vector<int> dimension = {-1};
@ -123,7 +123,7 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false,
// dEdp = softmax - 1 (or 0) // dEdp = softmax - 1 (or 0)
dLdp->assign(softmax); dLdp->assign(softmax);
// subtract unities at appropriate indexes of dLdp array // subtract unities at appropriate indexes of dLdp array
helpers::scatterForLoss(block.launchContext(), *labels, *dLdp, *labels /*actually third array is unnecessary for gradient calculation*/, true); helpers::scatterForLoss(block.launchContext(), *labels, *dLdp, *labels /*actually third array is unnecessary for gradient calculation*/, true);
return Status::OK(); return Status::OK();
@ -131,8 +131,8 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false,
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits_grad) { DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits_grad) {
getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS}); getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS});
} }
@ -149,14 +149,14 @@ DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits_grad) {
if (labelsShapeInfo[i] != logitsShapeInfo[i]) { if (labelsShapeInfo[i] != logitsShapeInfo[i]) {
equalSoft = false; equalSoft = false;
break; break;
} }
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str()); REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo)); DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));
Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace());
Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace());
return SHAPELIST(CONSTANT(dLdpShapeInfo)); return SHAPELIST(CONSTANT(dLdpShapeInfo));
} }

View File

@ -113,7 +113,7 @@ namespace ops {
originalIndices.linspace(0); originalIndices.linspace(0);
ops::dynamic_partition op; ops::dynamic_partition op;
auto res = op.execute({&originalIndices, indices}, {}, {numPartition}); auto res = op.execute({&originalIndices, indices}, {}, {numPartition});
REQUIRE_OK(res->status()); REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
ops::dynamic_stitch stichOp; ops::dynamic_stitch stichOp;
std::vector<NDArray*> partitions(numPartition * 2); std::vector<NDArray*> partitions(numPartition * 2);
for (size_t i = 0; i < res->size(); i++) { for (size_t i = 0; i < res->size(); i++) {
@ -122,7 +122,8 @@ namespace ops {
} }
auto result = stichOp.execute(partitions, {}, {numPartition}, {}, false); auto result = stichOp.execute(partitions, {}, {numPartition}, {}, false);
REQUIRE_OK(result->status()); REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
result->at(0)->reshapei(outputList[0]->getShapeAsVector());
outputList[1]->assign(indices); outputList[1]->assign(indices);
outputList[0]->assign(result->at(0)); outputList[0]->assign(result->at(0));

View File

@ -46,8 +46,11 @@ namespace nd4j {
max = &m2; max = &m2;
} }
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
int numBits = INT_ARG(0); int numBits = 8;
bool narrowed = INT_ARG(1); if (block.getIArguments() && block.getIArguments()->size())
numBits = INT_ARG(0);
bool narrowed = false;
//INT_ARG(1);
if (block.getIArguments()->size() == 2) { if (block.getIArguments()->size() == 2) {
numBits = INT_ARG(0); numBits = INT_ARG(0);
narrowed = INT_ARG(1); narrowed = INT_ARG(1);

View File

@ -44,10 +44,10 @@ namespace nd4j {
if (targetRank == 0) { // scalar only if (targetRank == 0) { // scalar only
determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
} }
else if (targetRank == 1) { // vector else if (targetRank == 1) { // vector
determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape));
} }
else { // only two last dimensions are excluded else { // only two last dimensions are excluded
determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape)); determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape));
} }
return SHAPELIST(determinantShape); return SHAPELIST(determinantShape);
@ -79,7 +79,7 @@ namespace nd4j {
REQUIRE_TRUE(input->rankOf() >=2, 0, "log_matrix_determinant: The rank of input array should not less than 2, but %i is given", input->rankOf()); REQUIRE_TRUE(input->rankOf() >=2, 0, "log_matrix_determinant: The rank of input array should not less than 2, but %i is given", input->rankOf());
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "log_matrix_determinant: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2)); REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "log_matrix_determinant: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
return helpers::log_abs_determinant(block.launchContext(), input, output); return helpers::logAbsDeterminant(block.launchContext(), input, output);
} }
DECLARE_SHAPE_FN(log_matrix_determinant) { DECLARE_SHAPE_FN(log_matrix_determinant) {
@ -91,7 +91,7 @@ namespace nd4j {
if (targetRank == 0) { // scalar only if (targetRank == 0) { // scalar only
determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
} }
else if (targetRank == 1) { // vector else if (targetRank == 1) { // vector
determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape));
} }
else { // only two last dimensions are excluded else { // only two last dimensions are excluded
@ -132,7 +132,7 @@ namespace nd4j {
if (targetRank == 0) { // scalar only if (targetRank == 0) { // scalar only
determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
} }
else if (targetRank == 1) { // vector else if (targetRank == 1) { // vector
determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape)); determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape));
} }
else { // only two last dimensions are excluded else { // only two last dimensions are excluded

View File

@ -31,36 +31,36 @@ namespace ops {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) { CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) {
auto x = INPUT_VARIABLE(0); // input [bS x inSize] auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size
auto hLast = INPUT_VARIABLE(1); // previous cell output [bS x numUnits], that is at previous time step t-1 auto hLast = INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
auto Wru = INPUT_VARIABLE(2); // RU weights - [(nIn+nOut), 2*numUnits] - reset and update gates (input/recurrent weights) auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights)
auto Wc = INPUT_VARIABLE(3); // C weights - [(nIn+nOut), numUnits] - cell gate (input/recurrent weights) auto Wc = INPUT_VARIABLE(3); // C weights - [nIn+nU, nU] - cell gate (input/recurrent weights)
auto bru = INPUT_VARIABLE(4); // reset and update biases, [2*numUnits] - reset and update gates auto bru = INPUT_VARIABLE(4); // reset and update biases, [2*nU] - reset and update gates
auto bc = INPUT_VARIABLE(5); // cell biases, [numUnits] auto bc = INPUT_VARIABLE(5); // cell biases, [nU]
auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, numUnits] auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, nU]
auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, numUnits] auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, nU]
auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, numUnits] auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, nU]
auto h = OUTPUT_VARIABLE(3); // current cell output [bS, numUnits] auto h = OUTPUT_VARIABLE(3); // current cell output [bS, nU]
REQUIRE_TRUE(x->rankOf()==2 && hLast->rankOf()==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", x->rankOf(), hLast->rankOf()); REQUIRE_TRUE(x->rankOf()==2 && hLast->rankOf()==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", x->rankOf(), hLast->rankOf());
const int rank = x->rankOf(); const int rank = x->rankOf();
const auto bS = x->sizeAt(0); const auto bS = x->sizeAt(0);
const auto nIn = x->sizeAt(1); const auto nIn = x->sizeAt(1);
const auto nU = hLast->sizeAt(1); const auto nU = hLast->sizeAt(1);
REQUIRE_TRUE(x->sizeAt(0) == hLast->sizeAt(0), 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast"); REQUIRE_TRUE(x->sizeAt(0) == hLast->sizeAt(0), 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast");
REQUIRE_TRUE(Wru->rankOf()==2 && Wc->rankOf()==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", Wru->rankOf(), Wc->rankOf()); REQUIRE_TRUE(Wru->rankOf()==2 && Wc->rankOf()==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", Wru->rankOf(), Wc->rankOf());
REQUIRE_TRUE(Wru->sizeAt(0)==(nIn+nU) && Wc->sizeAt(0)==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to inSize + numUnits, got %i", Wru->sizeAt(0)); REQUIRE_TRUE(Wru->sizeAt(0)==(nIn+nU) && Wc->sizeAt(0)==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to nIn + nU, got %i", Wru->sizeAt(0));
REQUIRE_TRUE(Wru->sizeAt(1)==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*numUnits, got %i", Wru->sizeAt(1)); REQUIRE_TRUE(Wru->sizeAt(1)==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*nU, got %i", Wru->sizeAt(1));
REQUIRE_TRUE(Wc->sizeAt(1)==nU, 0, "gruCell: Weights (cell) size(1) must be equal to numUnits, got %i", Wc->sizeAt(1)); REQUIRE_TRUE(Wc->sizeAt(1)==nU, 0, "gruCell: Weights (cell) size(1) must be equal to nU, got %i", Wc->sizeAt(1));
REQUIRE_TRUE(bru->rankOf()==1 && bru->sizeAt(0)==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*numUnits"); REQUIRE_TRUE(bru->rankOf()==1 && bru->sizeAt(0)==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*nU");
REQUIRE_TRUE(bc->rankOf()==1 && bc->sizeAt(0)==nU, 0, "gruCell: cell biases must be rank 1, size numUnits"); REQUIRE_TRUE(bc->rankOf()==1 && bc->sizeAt(0)==nU, 0, "gruCell: cell biases must be rank 1, size nU");
REQUIRE_TRUE(r->rankOf()==2 && u->rankOf()==2 && c->rankOf()==2 && h->rankOf()==2 && REQUIRE_TRUE(r->rankOf()==2 && u->rankOf()==2 && c->rankOf()==2 && h->rankOf()==2 &&
r->sizeAt(0)==bS && u->sizeAt(0)==bS && c->sizeAt(0)==bS && h->sizeAt(0)==bS && r->sizeAt(0)==bS && u->sizeAt(0)==bS && c->sizeAt(0)==bS && h->sizeAt(0)==bS &&
r->sizeAt(1)==nU && u->sizeAt(1)==nU && c->sizeAt(1)==nU && h->sizeAt(1)==nU, r->sizeAt(1)==nU && u->sizeAt(1)==nU && c->sizeAt(1)==nU && h->sizeAt(1)==nU,
0, "gruCell: Output arrays must all be rank 2 with size(0) == batchSize and size(1) == numUnits"); 0, "gruCell: Output arrays must all be rank 2 with size(0) == batchSize and size(1) == nU");
helpers::gruCell(block.launchContext(), x, hLast, Wru, Wc, bru, bc, r, u, c, h); helpers::gruCell(block.launchContext(), x, hLast, Wru, Wc, bru, bc, r, u, c, h);
@ -80,39 +80,39 @@ DECLARE_TYPES(gruCell) {
DECLARE_SHAPE_FN(gruCell) { DECLARE_SHAPE_FN(gruCell) {
auto x = inputShape->at(0); // input [bS x inSize] auto x = inputShape->at(0); // input [bS x nIn]
auto hLast = inputShape->at(1); // previous cell output [bS x numUnits], that is at previous time step t-1 auto hLast = inputShape->at(1); // previous cell output [bS x nU], that is at previous time step t-1
auto Wru = inputShape->at(2); // RU weights - [(nIn+nOut), 2*numUnits] - reset and update gates (input/recurrent weights) auto Wru = inputShape->at(2); // RU weights - [(nIn+nU), 2*nU] - reset and update gates (input/recurrent weights)
auto Wc = inputShape->at(3); // C weights - [(nIn+nOut), numUnits] - cell gate (input/recurrent weights) auto Wc = inputShape->at(3); // C weights - [(nIn+nU), nU] - cell gate (input/recurrent weights)
auto bru = inputShape->at(4); // reset and update biases, [2*numUnits] - reset and update gates auto bru = inputShape->at(4); // reset and update biases, [2*nU] - reset and update gates
auto bc = inputShape->at(5); // cell biases, [numUnits] auto bc = inputShape->at(5); // cell biases, [nU]
REQUIRE_TRUE(shape::rank(x)==2 && shape::rank(hLast)==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", shape::rank(x), shape::rank(hLast)); REQUIRE_TRUE(shape::rank(x)==2 && shape::rank(hLast)==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", shape::rank(x), shape::rank(hLast));
const int rank = x[0]; const int rank = x[0];
const auto bS = x[1]; const auto bS = x[1];
const auto inSize = x[2]; const auto nIn = x[2];
const auto numUnits = hLast[2]; const auto nU = hLast[2];
REQUIRE_TRUE(x[1] == hLast[1], 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast"); REQUIRE_TRUE(x[1] == hLast[1], 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast");
REQUIRE_TRUE(shape::rank(Wru)==2 && shape::rank(Wc)==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", shape::rank(Wru), shape::rank(Wc)); REQUIRE_TRUE(shape::rank(Wru)==2 && shape::rank(Wc)==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", shape::rank(Wru), shape::rank(Wc));
REQUIRE_TRUE(Wru[1]==(inSize+numUnits) && Wc[1]==(inSize+numUnits), 0, "gruCell: Weights size(0) must be equal to inSize + numUnits, got %i and %i", Wru[1], Wc[1]); REQUIRE_TRUE(Wru[1]==(nIn+nU) && Wc[1]==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to nIn + nU, got %i and %i", Wru[1], Wc[1]);
REQUIRE_TRUE(Wru[2]==(2*numUnits), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*numUnits, got %i", Wru[2]); REQUIRE_TRUE(Wru[2]==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*nU, got %i", Wru[2]);
REQUIRE_TRUE(Wc[2]==numUnits, 0, "gruCell: Weights (cell) size(1) must be equal to numUnits, got %i", Wc[2]); REQUIRE_TRUE(Wc[2]==nU, 0, "gruCell: Weights (cell) size(1) must be equal to nU, got %i", Wc[2]);
REQUIRE_TRUE(shape::rank(bru)==1 && bru[1]==(2*numUnits), 0, "gruCell: reset/update biases must be rank 1, size 2*numUnits"); REQUIRE_TRUE(shape::rank(bru)==1 && bru[1]==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*nU");
REQUIRE_TRUE(shape::rank(bc)==1 && bc[1]==numUnits, 0, "gruCell: cell biases must be rank 1, size numUnits"); REQUIRE_TRUE(shape::rank(bc)==1 && bc[1]==nU, 0, "gruCell: cell biases must be rank 1, size nU");
Nd4jLong *s0(nullptr); Nd4jLong *s0(nullptr);
ALLOCATE(s0, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);// [bS x numUnits] ALLOCATE(s0, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);// [bS x nU]
s0[0] = rank; s0[0] = rank;
s0[1] = bS; s0[1] = bS;
s0[2] = numUnits; s0[2] = nU;
ShapeUtils::updateStridesAndType(s0, x, shape::order(hLast)); ShapeUtils::updateStridesAndType(s0, x, shape::order(hLast));
auto ts0 = ConstantShapeHelper::getInstance()->createFromExisting(s0, block.workspace()); auto ts0 = ConstantShapeHelper::getInstance()->createFromExisting(s0, block.workspace());
//4 output shapes, all [bs, numUnits] //4 output shapes, all [bs, nU]
return SHAPELIST(ts0, ts0, ts0, ts0); return SHAPELIST(ts0, ts0, ts0, ts0);
} }

View File

@ -152,9 +152,9 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
auto inGradH = INPUT_VARIABLE(6); // [bS x K x N] auto inGradH = INPUT_VARIABLE(6); // [bS x K x N]
NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K] NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
bool applyMask = false; bool applyMask = false;
if (block.width() > 7) { if (block.width() > 7) {
mask = INPUT_VARIABLE(7); mask = INPUT_VARIABLE(7);
applyMask = true; applyMask = true;
} }
@ -166,7 +166,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
const int bS = x->shapeOf()[0]; const int bS = x->shapeOf()[0];
const int K = x->shapeOf()[1]; const int K = x->shapeOf()[1];
const int N = x->shapeOf()[2]; // N - number of time steps const int N = x->shapeOf()[2]; // N - number of time steps
auto gradBias = NDArrayFactory::create_(x->ordering(), {bS, 2*K, N}, gradX->dataType(), block.launchContext()); auto gradBias = NDArrayFactory::create_(x->ordering(), {bS, 2*K, N}, gradX->dataType(), block.launchContext());
auto gradU = NDArrayFactory::create_(x->ordering(), {bS, 3*K, N}, gradX->dataType(), block.launchContext()); auto gradU = NDArrayFactory::create_(x->ordering(), {bS, 3*K, N}, gradX->dataType(), block.launchContext());
auto gradHX = NDArrayFactory::create_(x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext()); auto gradHX = NDArrayFactory::create_(x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext());
@ -199,7 +199,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
std::vector<Nd4jLong> idx = {0,0, 0,0, 0,0}; std::vector<Nd4jLong> idx = {0,0, 0,0, 0,0};
for (int t = N-1; t >=0 ; --t) { for (int t = N-1; t >=0 ; --t) {
// initialization // initialization
idx[4] = t; idx[4] = t;
idx[5] = t + 1; idx[5] = t + 1;
@ -242,7 +242,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
gct->applyPairwiseTransform(pairwise::Subtract, &xt, temp1, nullptr); // temp1 = (g_ct - xt) gct->applyPairwiseTransform(pairwise::Subtract, &xt, temp1, nullptr); // temp1 = (g_ct - xt)
rtMinus->applyPairwiseTransform(pairwise::Multiply, &rt, temp2, nullptr); // temp2 = (1.0f - rt) * rt; rtMinus->applyPairwiseTransform(pairwise::Multiply, &rt, temp2, nullptr); // temp2 = (1.0f - rt) * rt;
temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, nullptr); // temp1 = (g_ct - xt) * (1.0f - rt) * rt; temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, nullptr); // temp1 = (g_ct - xt) * (1.0f - rt) * rt;
inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, &gradBRt, nullptr); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt; inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, &gradBRt, nullptr); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
// bF, TODO - tanh // bF, TODO - tanh
// gradTanh = (1.0f - g_ct * g_ct); // gradTanh = (1.0f - g_ct * g_ct);
@ -259,7 +259,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
temp1->applyPairwiseTransform(pairwise::Multiply, temp2, &gradBFt, nullptr); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft; temp1->applyPairwiseTransform(pairwise::Multiply, temp2, &gradBFt, nullptr); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft;
// x_t (highway connection), gradHXt = inGradHt * (1.0f - rt); // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt);
inGradHt.applyPairwiseTransform(pairwise::Multiply, rtMinus, &gradHXt, nullptr); inGradHt.applyPairwiseTransform(pairwise::Multiply, rtMinus, &gradHXt, nullptr);
// U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft); // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, temp1, nullptr); // temp1 = rt * grad_tanh rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, temp1, nullptr); // temp1 = rt * grad_tanh
@ -280,14 +280,14 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
// gradInit // gradInit
gradInit->assign(inGradCt); gradInit->assign(inGradCt);
// gradX // gradX
auto weightsT = w->transpose(); // [K x 3K] auto weightsT = w->transpose(); // [K x 3K]
MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N] MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N]
gradX->applyPairwiseTransform(pairwise::Add, gradHX, gradX, nullptr); // + grad_highway_x gradX->applyPairwiseTransform(pairwise::Add, gradHX, gradX, nullptr); // + grad_highway_x
if(applyMask) if(applyMask)
gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
// gradB // gradB
auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K] auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K]
gradB->assign(temp3); gradB->assign(temp3);
@ -298,7 +298,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
delete gct; delete gradU; delete gradHX; delete gct; delete gradU; delete gradHX;
delete temp1; delete temp2; delete temp3; delete gradCt; delete wi; delete temp1; delete temp2; delete temp3; delete gradCt; delete wi;
delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias; delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias;
return Status::OK(); return Status::OK();
} }
@ -322,21 +322,21 @@ DECLARE_SHAPE_FN(sru_bp) {
ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize}); ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4)); return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) { CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize] auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize]
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
NDArray* mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] NDArray* mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize] auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize]
auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize] auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize]
// input shapes validation // input shapes validation
const int rank = x->rankOf(); const int rank = x->rankOf();
@ -345,20 +345,20 @@ CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BI operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf()); REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BI operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf());
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
REQUIRE_TRUE(b->rankOf() <= rank-1, 0, "SRU_BI operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf()); REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf());
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
if(mask) if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf()); REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w); const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize}); const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(b); const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize}); const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0); const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize}); const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str()); REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str()); REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str()); REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(mask) { if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask); const std::string maskShape = ShapeUtils::shapeAsString(mask);
@ -366,7 +366,7 @@ CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
} }
helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct); helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct);
return Status::OK(); return Status::OK();
} }
@ -379,20 +379,20 @@ CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
DECLARE_SHAPE_FN(sru_bi) { DECLARE_SHAPE_FN(sru_bi) {
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
auto wShapeInfo = inputShape->at(1); auto wShapeInfo = inputShape->at(1);
auto bShapeInfo = inputShape->at(2); auto bShapeInfo = inputShape->at(2);
auto c0ShapeInfo = inputShape->at(3); auto c0ShapeInfo = inputShape->at(3);
Nd4jLong* maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize] Nd4jLong* maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
const int rank = xShapeInfo[0]; // = 3 const int rank = xShapeInfo[0]; // = 3
const Nd4jLong time = xShapeInfo[1]; const Nd4jLong time = xShapeInfo[1];
const Nd4jLong bS = xShapeInfo[2]; const Nd4jLong bS = xShapeInfo[2];
const Nd4jLong inSize = xShapeInfo[3] / 2; const Nd4jLong inSize = xShapeInfo[3] / 2;
// input shapes validation // input shapes validation
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]); REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
REQUIRE_TRUE(bShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of biases array, expected is <=2, but got %i instead !", rank-1, bShapeInfo[0]); REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]);
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]); REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
if(maskShapeInfo) if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]); REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
@ -400,12 +400,12 @@ DECLARE_SHAPE_FN(sru_bi) {
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo); const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize}); const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo); const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize}); const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo); const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize}); const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str()); REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str()); REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str()); REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(maskShapeInfo) { if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo); const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
@ -428,15 +428,15 @@ DECLARE_SHAPE_FN(sru_bi) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) { CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize] auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize] auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [4*inSize]
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0 auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize] auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize]
auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize] auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize]
auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize] auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize]
NDArray* mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize] NDArray* mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
// input shapes validation // input shapes validation
const int rank = x->rankOf(); const int rank = x->rankOf();
@ -445,7 +445,7 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
const Nd4jLong inSize = x->sizeAt(2) / 2; const Nd4jLong inSize = x->sizeAt(2) / 2;
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf()); REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
REQUIRE_TRUE(b->rankOf() <= rank-1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf()); REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf());
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf()); REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
REQUIRE_TRUE(ct->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ct->rankOf()); REQUIRE_TRUE(ct->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ct->rankOf());
REQUIRE_TRUE(inGradC0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0->rankOf()); REQUIRE_TRUE(inGradC0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0->rankOf());
@ -456,7 +456,7 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
const std::string wShape = ShapeUtils::shapeAsString(w); const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize}); const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(b); const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize}); const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0); const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize}); const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string ctShape = ShapeUtils::shapeAsString(ct); const std::string ctShape = ShapeUtils::shapeAsString(ct);
@ -470,7 +470,7 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
const std::string maskShape = ShapeUtils::shapeAsString(mask); const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str()); REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
} }
auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize] auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize]
auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize] auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize]
auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize] auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize]
@ -484,9 +484,9 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
DECLARE_SHAPE_FN(sru_bi_bp) { DECLARE_SHAPE_FN(sru_bi_bp) {
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ] auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
auto wShapeInfo = inputShape->at(1); auto wShapeInfo = inputShape->at(1);
auto bShapeInfo = inputShape->at(2); auto bShapeInfo = inputShape->at(2);
auto c0ShapeInfo = inputShape->at(3); auto c0ShapeInfo = inputShape->at(3);
auto ctShapeInfo = inputShape->at(4); auto ctShapeInfo = inputShape->at(4);
auto inGradC0ShapeInfo = inputShape->at(5); auto inGradC0ShapeInfo = inputShape->at(5);
auto inGradHtShapeInfo = inputShape->at(6); auto inGradHtShapeInfo = inputShape->at(6);
@ -499,7 +499,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
const Nd4jLong inSize = xShapeInfo[3] / 2; const Nd4jLong inSize = xShapeInfo[3] / 2;
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]); REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
REQUIRE_TRUE(bShapeInfo[0] <= rank-1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", bShapeInfo); REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]);
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo); REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo);
REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ctShapeInfo); REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ctShapeInfo);
REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0ShapeInfo[0]); REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0ShapeInfo[0]);
@ -510,7 +510,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo); const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize}); const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo); const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize}); const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo); const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize}); const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string ctShape = ShapeUtils::shapeAsString(ctShapeInfo); const std::string ctShape = ShapeUtils::shapeAsString(ctShapeInfo);
@ -535,11 +535,11 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize}); ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize});
ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, {time, 2 * inSize, 6 * inSize}); ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, {time, 2 * inSize, 6 * inSize});
ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, {1, 4 * inSize}); ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, {4 * inSize});
ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, {bS, 2 * inSize}); ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, {bS, 2 * inSize});
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4)); return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
} }
} }
} }
@ -549,15 +549,15 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**
* Implementation of operations for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi * Implementation of operations for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
* *
* Input arrays: * Input arrays:
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
* 1: 2d tensor of weights [3K x K] * 1: 2d tensor of weights [3K x K]
* 2: row of biases with twice length [1 x 2K] * 2: row of biases with twice length [1 x 2K]
* 3: 2d tensor of previous cell state [bS x K] * 3: 2d tensor of previous cell state [bS x K]
* 4: optional, 2d tensor of dropout mask [bS x K] * 4: optional, 2d tensor of dropout mask [bS x K]
* *
* Output arrays: * Output arrays:
* 0: 3d tensor of cell output [bS x K x N] * 0: 3d tensor of cell output [bS x K x N]
* 1: 3d tensor of cell state [bS x K x N] * 1: 3d tensor of cell state [bS x K x N]
*/ */
@ -568,15 +568,15 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**
* Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
* *
* Input arrays: * Input arrays:
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
* 1: 2d tensor of weights [3K x K] * 1: 2d tensor of weights [3K x K]
* 2: row of biases with twice length [1 x 2K] * 2: row of biases with twice length [1 x 2K]
* 3: 2d tensor of previous cell state [bS x K] * 3: 2d tensor of previous cell state [bS x K]
* 4: optional, 2d tensor of dropout mask [bS x K] * 4: optional, 2d tensor of dropout mask [bS x K]
* *
* Output arrays: * Output arrays:
* 0: 3d tensor of cell output [bS x K x N] * 0: 3d tensor of cell output [bS x K x N]
* 1: 3d tensor of cell state [bS x K x N] * 1: 3d tensor of cell state [bS x K x N]
*/ */
@ -588,8 +588,8 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
/** /**
* Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi * Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
* *
* Input arrays: * Input arrays:
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features * 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
* 1: 2d tensor of weights [3K x K] * 1: 2d tensor of weights [3K x K]
* 2: row of biases with twice length [1 x 2K] * 2: row of biases with twice length [1 x 2K]
@ -598,13 +598,13 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
* 5: 2d tensor of cell state gradients [bS x K] * 5: 2d tensor of cell state gradients [bS x K]
* 6: 3d tensor of state output gradients [bS x K x N] * 6: 3d tensor of state output gradients [bS x K x N]
* 7: optional, 2d tensor of dropout mask [bS x K] * 7: optional, 2d tensor of dropout mask [bS x K]
* *
* Output arrays: * Output arrays:
* 0: 3d tensor of input gradients [bS x K x N] * 0: 3d tensor of input gradients [bS x K x N]
* 1: 3d tensor of weights gradients [bS x 3K x K] * 1: 3d tensor of weights gradients [bS x 3K x K]
* 2: 2d, row of biases gradients [1 x 2K] * 2: 2d, row of biases gradients [1 x 2K]
* 3: 2d, tensor of state gradients [bS x K] * 3: 2d, tensor of state gradients [bS x K]
*/ */
// #if NOT_EXCLUDED(OP_sru_logic) // #if NOT_EXCLUDED(OP_sru_logic)
// DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0); // DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0);
// #endif // #endif
@ -618,27 +618,27 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
// } // }
///////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(sru_logic, 5, 2, false, 0, 0) { // CUSTOM_OP_IMPL(sru_logic, 5, 2, false, 0, 0) {
// auto input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features // auto input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
// auto weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K] // auto weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
// auto bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K] // auto bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K]
// auto init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0 // auto init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K] // NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
// bool applyMask = false; // bool applyMask = false;
// if (block.width() > 4) { // if (block.width() > 4) {
// mask = INPUT_VARIABLE(4); // mask = INPUT_VARIABLE(4);
// applyMask = true; // applyMask = true;
// } // }
// auto output = OUTPUT_VARIABLE(0); // h_t, [bS x K x N] // auto output = OUTPUT_VARIABLE(0); // h_t, [bS x K x N]
// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x K x N] // auto state = OUTPUT_VARIABLE(1); // c_t, [bS x K x N]
// const int bS = input->shapeOf()[0]; // bS - batch size // const int bS = input->shapeOf()[0]; // bS - batch size
// const int K = input->shapeOf()[1]; // K - number of features // const int K = input->shapeOf()[1]; // K - number of features
// const int N = input->shapeOf()[2]; // N - number of time steps // const int N = input->shapeOf()[2]; // N - number of time steps
// const auto wi = mmul(*weights, *input); // U [bS x 3K x N] // const auto wi = mmul(*weights, *input); // U [bS x 3K x N]
// const auto bF = (*bias)({0,0, 0, K}); // biases for forget gate [1 x K] // const auto bF = (*bias)({0,0, 0, K}); // biases for forget gate [1 x K]
// const auto bR = (*bias)({0,0, K,2*K}); // biases for reset gate [1 x K] // const auto bR = (*bias)({0,0, K,2*K}); // biases for reset gate [1 x K]
@ -664,7 +664,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
// ft = sigmoid_(ft + bF); // ft = sigmoid_(ft + bF);
// rt = sigmoid_(rt + bR); // rt = sigmoid_(rt + bR);
// ct = ft * (ct - zt) + zt; // ct = ft * (ct - zt) + zt;
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); // // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
// ct.applyTransform(transform::Tanh, &gct); // ct.applyTransform(transform::Tanh, &gct);
// ht = rt * (gct - xt) + xt; // ht = rt * (gct - xt) + xt;
@ -694,16 +694,16 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
// Nd4jLong* newShapeInfo1 = nullptr; // Nd4jLong* newShapeInfo1 = nullptr;
// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong); // ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong);
// newShapeInfo1[0] = rank; // newShapeInfo1[0] = rank;
// newShapeInfo1[1] = bS; // newShapeInfo1[1] = bS;
// newShapeInfo1[2] = K; // newShapeInfo1[2] = K;
// newShapeInfo1[3] = N; // newShapeInfo1[3] = N;
// ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order); // ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order);
// auto result = CONSTANT(newShapeInfo1); // auto result = CONSTANT(newShapeInfo1);
// return SHAPELIST(result, result); // return SHAPELIST(result, result);
// } // }
// ////////////////////////////////////////////////////////////////////////// // //////////////////////////////////////////////////////////////////////////
@ -860,7 +860,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
// const std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, inSize}); // const std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, inSize});
// const std::string inGradHShape = ShapeUtils::shapeAsString(inGradH); // const std::string inGradHShape = ShapeUtils::shapeAsString(inGradH);
// const std::string inGradHCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time}); // const std::string inGradHCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time});
// REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str()); // REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
// // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str()); // // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
// REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str()); // REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
@ -896,11 +896,11 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
// auto inGradHt = (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize] // auto inGradHt = (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize]
// auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1} // auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1}
// ///////////////// forward // ///////////////// forward
// // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR) // // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
// ft = sigmoid_(ft + bF); // ft = sigmoid_(ft + bF);
// rt = sigmoid_(rt + bR); // rt = sigmoid_(rt + bR);
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur ); // // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
// ct.applyTransform(transform::Tanh, &gct); // ct.applyTransform(transform::Tanh, &gct);
@ -910,7 +910,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
// NDArray ftMinus = 1. - ft; // NDArray ftMinus = 1. - ft;
// NDArray rtMinus = 1. - rt; // NDArray rtMinus = 1. - rt;
// NDArray gradBRt = inGradHt * (gct - xt) * rtMinus * rt; // NDArray gradBRt = inGradHt * (gct - xt) * rtMinus * rt;
// // bF, TODO - tanh // // bF, TODO - tanh
// NDArray gradTanh = 1. - gct * gct; // NDArray gradTanh = 1. - gct * gct;
// NDArray gradCt = inGradHt * rt * gradTanh; // NDArray gradCt = inGradHt * rt * gradTanh;
// NDArray gradBFt = (gradCt + *inGradCt) * (ct_1 - zt) * ftMinus * ft; // NDArray gradBFt = (gradCt + *inGradCt) * (ct_1 - zt) * ftMinus * ft;
@ -923,7 +923,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
// // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft; // // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft;
// *inGradCt = (gradCt + *inGradCt) * ft; // *inGradCt = (gradCt + *inGradCt) * ft;
// // save results // // save results
// gradBias({0,0, 0,inSize, t,t+1}, true).assign(gradBFt); // gradBias({0,0, 0,inSize, t,t+1}, true).assign(gradBFt);
// gradBias({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBRt); // gradBias({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBRt);
// gradU({0,0, 0,inSize, t,t+1}, true).assign(gradUZt); // gradU({0,0, 0,inSize, t,t+1}, true).assign(gradUZt);
@ -934,19 +934,19 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
// // gradInit // // gradInit
// gradInit->assign(inGradCt); // gradInit->assign(inGradCt);
// // gradX // // gradX
// w->transposei(); // [inSize x 3K] // w->transposei(); // [inSize x 3K]
// gradX->assign( mmul(*w, gradU) + gradHX); // gradX->assign( mmul(*w, gradU) + gradHX);
// if(mask) // if(mask)
// gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask // gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
// // gradB // // gradB
// gradBias.reduceAlongDimension(reduce::Sum, gradB, {0,2}, false, true); // [1 x 2K] // gradBias.reduceAlongDimension(reduce::Sum, gradB, {0,2}, false, true); // [1 x 2K]
// // gradW [bS x 3K x inSize] // // gradW [bS x 3K x inSize]
// x->permutei({0, 2, 1}); // [bS x time x inSize] // x->permutei({0, 2, 1}); // [bS x time x inSize]
// gradW->assign( mmul(gradU, *x) ); // gradW->assign( mmul(gradU, *x) );
// return Status::OK(); // return Status::OK();
// } // }
@ -969,4 +969,4 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
// ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize}); // ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
// return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4)); // return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
// } // }

View File

@ -293,46 +293,46 @@ void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& outpu
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) { void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) {
const Nd4jLong inputLen = input.lengthOf(); const Nd4jLong inputLen = input.lengthOf();
const Nd4jLong* inputShapeInfo = input.getShapeInfo(); const Nd4jLong* inputShapeInfo = input.getShapeInfo();
const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo(); const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo();
PRAGMA_OMP_PARALLEL_FOR_IF(inputLen > Environment::getInstance()->elementwiseThreshold()) PRAGMA_OMP_PARALLEL_FOR_IF(inputLen > Environment::getInstance()->elementwiseThreshold())
for(Nd4jLong i = 0; i < inputLen; ++i) { for(Nd4jLong i = 0; i < inputLen; ++i) {
// FIXME: double! // FIXME: double!
double x = input.e<double>(i); double x = input.e<double>(i);
if(x < 0.0) { if(x < 0.0) {
// FIXME: double
output.p(i, (x * alpha.e<double>(shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo))));
} else
output.p(i, x);
}
}
//////////////////////////////////////////////////////////////////////////
void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
const Nd4jLong inputLen = input.lengthOf();
const Nd4jLong* inputShapeInfo = input.getShapeInfo();
const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo();
dLdA.assign(0.0f);
for(Nd4jLong i = 0; i < inputLen; ++i) {
// FIXME: double // FIXME: double
double x = input.e<double>(i); output.p(i, (x * alpha.e<double>(shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo))));
double grO = dLdO.e<double>(i); } else
if(x < 0.0) { output.p(i, x);
Nd4jLong alphaInd = shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo); }
dLdI.p(i, grO * alpha.e<double>(alphaInd)); }
double prevVal = dLdA.e<double>(alphaInd);
prevVal += (grO * x); //////////////////////////////////////////////////////////////////////////
dLdA.p(alphaInd, prevVal ); void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
}
else const Nd4jLong inputLen = input.lengthOf();
dLdI.p(i, grO); const Nd4jLong* inputShapeInfo = input.getShapeInfo();
const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo();
dLdA.assign(0.0f);
for(Nd4jLong i = 0; i < inputLen; ++i) {
// FIXME: double
double x = input.e<double>(i);
double grO = dLdO.e<double>(i);
if(x < 0.0) {
Nd4jLong alphaInd = shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo);
dLdI.p(i, grO * alpha.e<double>(alphaInd));
double prevVal = dLdA.e<double>(alphaInd);
prevVal += (grO * x);
dLdA.p(alphaInd, prevVal);
}
else
dLdI.p(i, grO);
} }
} }

View File

@ -45,7 +45,7 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
} }
// auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt); // sigmaInvGam = 1 / sqrt(variance + epsilon) // auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt); // sigmaInvGam = 1 / sqrt(variance + epsilon)
// if(gamma != nullptr) sigmaInvGam *= *gamma; // if(gamma != nullptr) sigmaInvGam *= *gamma;
const T* sigmaBuff = sigmaInvGam.bufferAsT<T>(); const T* sigmaBuff = sigmaInvGam.bufferAsT<T>();
const T* meanBuff = mean->bufferAsT<T>(); const T* meanBuff = mean->bufferAsT<T>();
@ -60,8 +60,8 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
uint inShapeInfoCast[MAX_RANK]; uint inShapeInfoCast[MAX_RANK];
uint meanShapeInfoCast[MAX_RANK]; uint meanShapeInfoCast[MAX_RANK];
bool canCastIn = nd4j::DataTypeUtils::castShapeInfo(inShapeInfo, inShapeInfoCast); bool canCastIn = nd4j::DataTypeUtils::castShapeInfo(inShapeInfo, inShapeInfoCast);
bool canCastMean = nd4j::DataTypeUtils::castShapeInfo(meanShapeInfo, meanShapeInfoCast); bool canCastMean = nd4j::DataTypeUtils::castShapeInfo(meanShapeInfo, meanShapeInfoCast);
const Nd4jLong step = lenBig / lenSmall; const Nd4jLong step = lenBig / lenSmall;
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes); std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes);
@ -70,58 +70,62 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
if(beta != nullptr) { if(beta != nullptr) {
const T* betaBuff = beta->bufferAsT<T>(); const T* betaBuff = beta->bufferAsT<T>();
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{ {
const auto threadNum = omp_get_thread_num(); const auto threadNum = omp_get_thread_num();
Nd4jLong* inOffsets = new Nd4jLong[step]; Nd4jLong* inOffsets = new Nd4jLong[step];
Nd4jLong* memBuff = new Nd4jLong[2 * inShapeInfo[0]];
for (int j = 0; j < lenSmall; ++j) {
for (int j = 0; j < lenSmall; ++j) {
const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads; const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads;
if (!isOwner) continue; if (!isOwner) continue;
const Nd4jLong start = j * step; const Nd4jLong start = j * step;
const Nd4jLong end = start + step; const Nd4jLong end = start + step;
// calculate offset for mean, variance, gamma, beta (all of them have the same shape) // calculate offset for mean, variance, gamma, beta (all of them have the same shape)
auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean); auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean);
// calculate offset for input and output (all of them have the same shape) // calculate offset for input and output (all of them have the same shape)
shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, dimsToExclude.data()); shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data());
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < step; ++i) { for (Nd4jLong i = 0; i < step; ++i) {
auto offsetBig = inOffsets[i]; auto offsetBig = inOffsets[i];
outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall] + betaBuff[offsetSmall]; outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall] + betaBuff[offsetSmall];
} }
} }
delete []inOffsets; delete []inOffsets;
} delete []memBuff;
}
} }
else { else {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{ {
const auto threadNum = omp_get_thread_num(); const auto threadNum = omp_get_thread_num();
Nd4jLong* inOffsets = new Nd4jLong[step]; Nd4jLong* inOffsets = new Nd4jLong[step];
Nd4jLong* memBuff = new Nd4jLong[2 * inShapeInfo[0]];
for (int j = 0; j < lenSmall; ++j) {
for (int j = 0; j < lenSmall; ++j) {
const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads; const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads;
if (!isOwner) continue; if (!isOwner) continue;
const Nd4jLong start = j * step; const Nd4jLong start = j * step;
const Nd4jLong end = start + step; const Nd4jLong end = start + step;
// calculate offset for mean, variance, gamma, beta (all of them have the same shape) // calculate offset for mean, variance, gamma, beta (all of them have the same shape)
auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean); auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean);
// calculate offset for input and output (all of them have the same shape) // calculate offset for input and output (all of them have the same shape)
shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, dimsToExclude.data()); shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data());
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < step; ++i) { for (Nd4jLong i = 0; i < step; ++i) {
auto offsetBig = inOffsets[i]; auto offsetBig = inOffsets[i];
outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall]; outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall];
} }
} }
delete []inOffsets; delete []inOffsets;
delete []memBuff;
} }
} }
} }

View File

@ -30,11 +30,20 @@ namespace helpers {
if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case
auto lesser = (x_shape->lengthOf() == 1 ? x_shape: y_shape); // lenght are equals
auto greater = (x_shape->lengthOf() == 1 ? y_shape: x_shape); if (x_shape->lengthOf() == y_shape->lengthOf()) {
output->assign(greater); auto greater = (x_shape->e<Nd4jLong>(0) < y_shape->e<Nd4jLong>(0) ? y_shape : x_shape);
output->assign(greater);
output->p(greater->lengthOf() - 1, lesser->e(0L)); }
else {
auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape);
auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape);
output->assign(greater);
auto lastG = greater->lengthOf() - 1;
auto lastL = lesser->lengthOf() - 1;
if (greater->e<Nd4jLong>(lastG) < lesser->e<Nd4jLong>(lastL))
output->p(lastG, lesser->e(lastL));
}
} }
else { else {
//int e = 0, x = 0, y = 0; //int e = 0, x = 0, y = 0;

View File

@ -1418,18 +1418,15 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
sum += pIn[kh + kw]; sum += pIn[kh + kw];
if (extraParam0 == 0) { //Exclude padding
auto oi = b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3; int a = (hend-hstart)/iStep2 + ((hend-hstart) % iStep2 == 0 ? 0 : 1);
int b = (wend-wstart)/iStep3 + ((wend-wstart) % iStep3 == 0 ? 0 : 1);
if (extraParam0 == 0) { //Exclude padding sum /= static_cast<T>(a * b); // Accounts for dilation
int _a = (hend-hstart)/iStep2 + ((hend-hstart) % iStep2 == 0 ? 0 : 1); }
int _b = (wend-wstart)/iStep3 + ((wend-wstart) % iStep3 == 0 ? 0 : 1); else if (extraParam0 == 1) //Include padding
sum /= _a * _b; //Accounts for dilation
} else if (extraParam0 == 1) //Include padding
sum /= kProd; sum /= kProd;
out[oi] = sum; out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
} }
} }
} }

View File

@ -1,76 +1,92 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyrigkht (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * Tkhis program and tkhe accompanying materials are made available under tkhe
* terms of the Apache License, Version 2.0 which is available at * terms of tkhe Apackhe License, Version 2.0 wkhickh is available at
* https://www.apache.org/licenses/LICENSE-2.0. * khttps://www.apackhe.org/licenses/LICENSE-2.0.
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * distributed under tkhe License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * WARRANTIES OR CONDITIONS OF ANY KIND, eitkher express or implied. See tkhe
* License for the specific language governing permissions and limitations * License for tkhe specific language governing permissions and limitations
* under the License. * under tkhe License.
* *
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apackhe-2.0
******************************************************************************/ ******************************************************************************/
// //
// @author raver119@gmail.com // @autkhor raver119@gmail.com
// //
#include <ops/declarable/helpers/dilation2d.h> #include <ops/declarable/helpers/dilation2d.h>
#include <array/DataTypeUtils.h> #include <array/DataTypeUtils.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename X, typename Y>
static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left) {
const int batch = input->sizeAt(0);
const int input_rows = input->sizeAt(1);
const int input_cols = input->sizeAt(2);
const int depth = input->sizeAt(3);
const int filter_rows = weights->sizeAt(0); //////////////////////////////////////////////////////////////////////
const int filter_cols = weights->sizeAt(1); template <typename X, typename Z>
static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
const int output_rows = output->sizeAt(1); // input [bS, iH, iW, iC]
const int output_cols = output->sizeAt(2); // weights [kH, kW, iC]
// output [bS, oH, oW, iC]
PRAGMA_OMP_PARALLEL_FOR_SIMD const X* x = input->bufferAsT<X>();
for (int b = 0; b < batch; ++b) { const X* y = weights->bufferAsT<X>();
for (int h_out = 0; h_out < output_rows; ++h_out) { Z* z = output->bufferAsT<Z>();
int h_beg = h_out * stride_rows - pad_top;
for (int w_out = 0; w_out < output_cols; ++w_out) { const Nd4jLong* xShapeInfo = input->getShapeInfo();
int w_beg = w_out * stride_cols - pad_left; const Nd4jLong* yShapeInfo = weights->getShapeInfo();
for (int d = 0; d < depth; ++d) { const Nd4jLong* zShapeInfo = output->getShapeInfo();
Y cur_val = -DataTypeUtils::max<Y>();
for (int h = 0; h < filter_rows; ++h) { const uint bS = input->sizeAt(0);
const int h_in = h_beg + h * rate_rows; const uint iH = input->sizeAt(1);
if (h_in >= 0 && h_in < input_rows) { const uint iW = input->sizeAt(2);
for (int w = 0; w < filter_cols; ++w) { const uint iC = input->sizeAt(3);
const int w_in = w_beg + w * rate_cols;
if (w_in >= 0 && w_in < input_cols) { const uint kH = weights->sizeAt(0);
const Y val = input->e<Y>(b, h_in, w_in, d) + weights->e<Y>(h, w, d); const uint kW = weights->sizeAt(1);
if (val > cur_val) {
cur_val = val; const uint oH = output->sizeAt(1);
} const uint oW = output->sizeAt(2);
}
} PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(4))
} for (uint b = 0; b < bS; ++b) {
for (uint oh = 0; oh < oH; ++oh) {
for (uint ow = 0; ow < oW; ++ow) {
for (uint c = 0; c < iC; ++c) {
X max = -DataTypeUtils::max<X>();
for (uint kh = 0; kh < kH; ++kh) {
const int ih = oh * sH - pH + kh * dH;
if (ih < 0 || ih >= iH) continue;
for (uint kw = 0; kw < kW; ++kw) {
const int iw = ow * sW - pW + kw * dW;
if(iw < 0 || iw >= iW) continue;
const X val = x[shape::getOffset(xShapeInfo, {b,(uint)ih,(uint)iw,c})] + y[shape::getOffset(yShapeInfo, {kh,kw,c})];
if (val > max)
max = val;
} }
(*output).p<Y>(b, h_out, w_out, d, cur_val);
} }
z[shape::getOffset(zShapeInfo, {b,oh,ow,c})] = static_cast<Z>(max);
} }
} }
} }
};
void dilation2d(nd4j::LaunchContext * context, NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left) {
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, stride_rows, stride_cols, rate_rows, rate_cols, pad_top, pad_left), LIBND4J_TYPES, FLOAT_TYPES);
} }
}
BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES);
void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES);
}
BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left), LIBND4J_TYPES, FLOAT_TYPES);
} }
} }

View File

@ -44,9 +44,9 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
template <typename T> template <typename T>
int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { int dropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
//NativeOps native; //NativeOps native;
//nd4j::graph::RandomGenerator nodeRng(seed); //static int _dropOutFunctor(nd4j::random::RandomBuffer* rng, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { //nd4j::graph::RandomGenerator nodeRng(seed); //static int dropOutFunctor_(nd4j::random::RandomBuffer* rng, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
//NativeOps native; //NativeOps native;
//native.reSeedBuffer(nullptr, (long)seed, rng); //native.reSeedBuffer(nullptr, (long)seed, rng);
//if (newRng ) //if (newRng )
@ -78,9 +78,9 @@ namespace helpers {
// broadcast chunk to full matrix // broadcast chunk to full matrix
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input)); std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
dropOutMultiplier->assign(1.f); dropOutMultiplier->assign(1.f);
*dropOutMultiplier += *chunk; *dropOutMultiplier += *chunk;
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
} }
@ -90,10 +90,10 @@ namespace helpers {
int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
auto xType = input->dataType(); auto xType = input->dataType();
BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, return dropOutFunctor_, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int dropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES);
/////////////////////////////////// backrpopagations /////////////////////////////////////////////// /////////////////////////////////// backrpopagations ///////////////////////////////////////////////
template <typename T> template <typename T>

View File

@ -40,51 +40,75 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
NDArray* r, NDArray* u, NDArray* c, NDArray* h) { NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
//Inputs: //Inputs:
// x input [bS x inSize] // x input [bS, nIn], nIn - input size
// hLast previous cell output [bS x numUnits], that is at previous time step t-1 // hLast previous cell output [bS, nUn], that is at previous time step t-1, nUn - number of units
// Wru RU weights - [bS, 2*numUnits] - reset and update gates // Wru RU weights - [nIn+nUn, 2*nUn] - reset and update gates
// Wc C weights - [bS, numUnits] - cell gate // Wc C weights - [nIn+nUn, nUn] - cell gate
// bru r and u biases, [2*numUnits] - reset and update gates // bru r and u biases, [2*nUn] - reset and update gates
// bc c biases, [numUnits] - cell gate // bc c biases, [nUn] - cell gate
//Outputs: //Outputs:
// r Reset gate output [bS, numUnits] // r Reset gate output [bS, nUn]
// u Update gate output [bS, numUnits] // u Update gate output [bS, nUn]
// c Cell gate output [bS, numUnits] // c Cell gate output [bS, nUn]
// h current cell output [bS, numUnits] // h current cell output [bS, nUn]
/***************************************************************************************/
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
const int bS = x->sizeAt(0);
const int nIn = x->sizeAt(1); const int nIn = x->sizeAt(1);
const int nU = hLast->sizeAt(1); // number of units const int nUn = hLast->sizeAt(1);
//Concat inputs: [x, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)] NDArray Wr = (*Wru)({0,nIn, 0,0}); // reset gates weights [nIn, 2*nUn]
nd4j::ops::concat concatOp; NDArray Wu = (*Wru)({nIn,nIn+nUn, 0,0}); // updates gates weights [nUn, 2*nUn]
std::vector<NDArray*> inputs;
std::vector<double> targs;
std::vector<Nd4jLong> iargs({1}); //Axis = 1
std::vector<bool> bargs;
inputs.emplace_back(const_cast<NDArray*>(x));
inputs.emplace_back(const_cast<NDArray*>(hLast));
auto result = concatOp.execute(inputs, targs, iargs, bargs); NDArray Wcr = (*Wc)({0,nIn, 0,0}); // reset cell weights [nIn, nUn]
auto concatOut = result->at(0); NDArray Wcu = (*Wc)({nIn,nIn+nUn, 0,0}); // updates cell weights [nUn, nUn]
//mmul/z for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u) // gates = sigmoid(x*Wr + hLast*Wu + br + bu)
auto m = mmul(*concatOut, *Wru); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 2*numUnits] = [bs, 4*numUnits] NDArray gates = mmul(*x, Wr) + mmul(*hLast, Wu) + *bru; // [bS, nIn] * [nIn, 2*nUn] + [bS, nUn] * [nUn, 2*nUn] + [2*nUn] = [bS, 2*nUn]
m += (*bru); gates.applyTransform(transform::Sigmoid);
// reset gate
r->assign(gates({0,0, 0,nUn})); // [bS, nUn]
// update gate
u->assign(gates({0,0, nUn,2*nUn})); // [bS, nUn]
// cell gate c = activation(x*Wcr + (r◦hlast)*Wcu + bc)
c->assign(mmul(*x, Wcr) + mmul(*r * *hLast, Wcu) + *bc); // [bS, nIn] * [nIn, nUn] + [bS, nUn] * [nUn, nUn] + [nUn] = [bS, nUn]
c->applyTransform(transform::Tanh);
// cell output
h->assign(*u * *hLast + (1.f - *u) * *c);
/***************************************************************************************/
/********************** THIS MORE OPTIMAZED CODE (except concat ) **********************/
/***************************************************************************************/
/*
//Concat inputs: x + hLast : [bs, nIn + nUn]
NDArray xhConcat(x->ordering(), {bS, nIn + nUn}, x->dataType(), context); // concat([bs, nIn], [bs, nUn]) -> [bs, nIn + nUn]
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
//mmul for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u)
auto m = mmul(xhConcat, *Wru) + *bru ; // [bs, nIn+nUn] * [nIn+nUn, 2*nUn] = [bs, 2*nUn]
// m += *bru;
sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz) sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz)
auto mr = m({0,0, 0, nU});
auto mu = m({0,0, nU, 2*nU});
r->assign(&mr); r->assign(m({0,0, 0, nUn}));
u->assign(&mu); u->assign(m({0,0, nUn, 2*nUn}));
//Concatenated inputs: [x, yt-1 .* r] // hLast = hLast * r
auto yr = (*concatOut)({0,0, nIn, nIn+nU}); xhConcat({0,0, nIn, nIn+nUn}) *= *r;
yr *= (*r);
//c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c) //c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c)
MmulHelper::mmul(concatOut, const_cast<NDArray*>(Wc), c, 1.0, 0.0); //c = 1.0 * concatOut * Wc + 0.0 * c MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
*c += *bc; *c += *bc;
tanhInplace(*c); tanhInplace(*c);
@ -94,135 +118,134 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
auto temp = (1.0f - *u); auto temp = (1.0f - *u);
temp *= (*c); temp *= (*c);
(*h) += temp; (*h) += temp;
*/
delete result;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
// x input [time, bS, iS] // x input [time, bS, iS]
// h0 initial cell output (at time step = 0) [bS, nU] // h0 initial cell output (at time step = 0) [bS, nUn]
// Wx input-to-hidden weights, [iS, 3*nU] // Wx input-to-hidden weights, [iS, 3*nUn]
// Wh hidden-to-hidden weights, [nU, 3*nU] // Wh hidden-to-hidden weights, [nUn, 3*nUn]
// b biases, [3*nU] // b biases, [3*nUn]
// h is cell outputs at each time step [time, bS, nU] // h is cell outputs at each time step [time, bS, nUn]
const int time = x->sizeAt(0); const int time = x->sizeAt(0);
NDArray ht_1(*h0); NDArray ht_1(*h0);
// loop through time steps // loop through time steps
for (int t = 0; t < time; ++t) { for (int t = 0; t < time; ++t) {
auto xt = (*x)({t,t+1, 0,0, 0,0}); auto xt = (*x)({t,t+1, 0,0, 0,0});
auto ht = (*h)({t,t+1, 0,0, 0,0}); auto ht = (*h)({t,t+1, 0,0, 0,0});
//helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht); //helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
//ht_1.assign(ht); //ht_1.assign(ht);
} }
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0, void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
// x input [bS, iS] // x input [bS, iS]
// h0 previous cell output [bS, nU], that is at previous time step t-1 // h0 previous cell output [bS, nUn], that is at previous time step t-1
// Wx input-to-hidden weights, [iS, 3*nU] // Wx input-to-hidden weights, [iS, 3*nUn]
// Wh hidden-to-hidden weights, [nU, 3*nU] // Wh hidden-to-hidden weights, [nUn, 3*nUn]
// b biases, [3*nU] // b biases, [3*nUn]
// dLdh gradient wrt output, [bS,nU], that is epsilon_next // dLdh gradient wrt output, [bS,nUn], that is epsilon_next
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nU] // dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nUn]
// dLdWh0 gradient wrt Wh at previous time step, [nU, 3*nU] // dLdWh0 gradient wrt Wh at previous time step, [nUn, 3*nUn]
// dLdb0 gradient wrt b at previous time step, [3*nU] // dLdb0 gradient wrt b at previous time step, [3*nUn]
// dLdx gradient wrt x, [bS, iS], that is epsilon // dLdx gradient wrt x, [bS, iS], that is epsilon
// dLdh0 gradient wrt h0, [bS, nU] // dLdh0 gradient wrt h0, [bS, nUn]
// dLdWx gradient wrt Wx, [iS, 3*nU] // dLdWx gradient wrt Wx, [iS, 3*nUn]
// dLdWh gradient wrt Wh, [nU, 3*nU] // dLdWh gradient wrt Wh, [nUn, 3*nUn]
// dLdb gradient wrt b at previous time step, [3*nU] // dLdb gradient wrt b at previous time step, [3*nUn]
// h is current cell output [bS, nU], that is at current time step t // h is current cell output [bS, nUn], that is at current time step t
const int nU = h0->sizeAt(1); const int nUn = h0->sizeAt(1);
// ***** feed forward step ***** // // ***** feed forward step ***** //
// gates = sigmoid(x*Wx + h0*Wh + b) // gates = sigmoid(x*Wx + h0*Wh + b)
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nU})) + mmul(*h0, (*Wh)({0,0, 0,2*nU})) + (*b)({0,2*nU})); // [bS, 2*nU] + [bS, 2*nU] + [1, 2*nU] = [bS, 2*nU] auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nUn})) + mmul(*h0, (*Wh)({0,0, 0,2*nUn})) + (*b)({0,2*nUn})); // [bS, 2*nUn] + [bS, 2*nUn] + [1, 2*nUn] = [bS, 2*nUn]
// reset gate // reset gate
auto r = gates({0,0, 0, nU}); // [bS, nU] auto r = gates({0,0, 0, nUn}); // [bS, nUn]
// update gate // update gate
auto u = gates({0,0, nU, 2*nU}); // [bS, nU] auto u = gates({0,0, nUn, 2*nUn}); // [bS, nUn]
// ◦ means element-wise product or so called Hadamard product // ◦ means element-wise product or so called Hadamard product
// n = tanh(x*Wx + (r◦h0)*Wh + b) // n = tanh(x*Wx + (r◦h0)*Wh + b)
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nU,3*nU})) + mmul((*h0)*r, (*Wh)({0,0, 2*nU,3*nU})) + (*b)({2*nU,3*nU})); // [bS, nU] auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nUn,3*nUn})) + mmul((*h0)*r, (*Wh)({0,0, 2*nUn,3*nUn})) + (*b)({2*nUn,3*nUn})); // [bS, nUn]
// ***** back prop step ***** // // ***** back prop step ***** //
auto Wxr = (*Wx)({0,0, 0, nU}); auto Wxr = (*Wx)({0,0, 0, nUn});
auto Wxu = (*Wx)({0,0, nU, 2*nU}); auto Wxu = (*Wx)({0,0, nUn, 2*nUn});
auto Wxn = (*Wx)({0,0, 2*nU,3*nU}); auto Wxn = (*Wx)({0,0, 2*nUn,3*nUn});
auto Whr = (*Wh)({0,0, 0, nU}); auto Whr = (*Wh)({0,0, 0, nUn});
auto Whu = (*Wh)({0,0, nU, 2*nU}); auto Whu = (*Wh)({0,0, nUn, 2*nUn});
auto Whn = (*Wh)({0,0, 2*nU,3*nU}); auto Whn = (*Wh)({0,0, 2*nUn,3*nUn});
auto WxrT = Wxr.transpose(); auto WxrT = Wxr.transpose();
auto WxuT = Wxu.transpose(); auto WxuT = Wxu.transpose();
auto WxnT = Wxn.transpose(); auto WxnT = Wxn.transpose();
auto WhrT = Whr.transpose(); auto WhrT = Whr.transpose();
auto WhuT = Whu.transpose(); auto WhuT = Whu.transpose();
auto WhnT = Whn.transpose(); auto WhnT = Whn.transpose();
auto xT = x->transpose(); auto xT = x->transpose();
auto h0T = h0->transpose(); auto h0T = h0->transpose();
auto dLdWxr = (*dLdWx)({0,0, 0, nU}); auto dLdWxr = (*dLdWx)({0,0, 0, nUn});
auto dLdWxu = (*dLdWx)({0,0, nU, 2*nU}); auto dLdWxu = (*dLdWx)({0,0, nUn, 2*nUn});
auto dLdWxn = (*dLdWx)({0,0, 2*nU,3*nU}); auto dLdWxn = (*dLdWx)({0,0, 2*nUn,3*nUn});
auto dLdWhr = (*dLdWh)({0,0, 0, nU}); auto dLdWhr = (*dLdWh)({0,0, 0, nUn});
auto dLdWhu = (*dLdWh)({0,0, nU, 2*nU}); auto dLdWhu = (*dLdWh)({0,0, nUn, 2*nUn});
auto dLdWhn = (*dLdWh)({0,0, 2*nU,3*nU}); auto dLdWhn = (*dLdWh)({0,0, 2*nUn,3*nUn});
auto dLdbr = (*dLdb)({0, nU}); auto dLdbr = (*dLdb)({0, nUn});
auto dLdbu = (*dLdb)({nU, 2*nU}); auto dLdbu = (*dLdb)({nUn, 2*nUn});
auto dLdbn = (*dLdb)({2*nU,3*nU}); auto dLdbn = (*dLdb)({2*nUn,3*nUn});
auto dhdu = *h0 - n; // [bS, nU] auto dhdu = *h0 - n; // [bS, nUn]
auto dhdn = 1.f - u; // [bS, nU] auto dhdn = 1.f - u; // [bS, nUn]
auto dSigdu = u * (1.f - u); // [bS, nU] auto dSigdu = u * (1.f - u); // [bS, nUn]
auto dSigdr = r * (1.f - r); // [bS, nU] auto dSigdr = r * (1.f - r); // [bS, nUn]
auto dActdn = 1.f - n * n; // [bS, nU] auto dActdn = 1.f - n * n; // [bS, nUn]
auto dndr = mmul(dActdn * (*h0), WhnT); auto dndr = mmul(dActdn * (*h0), WhnT);
auto drdh0 = mmul(dSigdr, WhrT); auto drdh0 = mmul(dSigdr, WhrT);
auto dLdn = (*dLdh) * dhdn; auto dLdn = (*dLdh) * dhdn;
auto dLdu = (*dLdh) * dhdu; auto dLdu = (*dLdh) * dhdu;
auto dLdr = dLdn * dndr; auto dLdr = dLdn * dndr;
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS] dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nU] dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nUn]
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nU] dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nUn]
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nU,nU] dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nUn,nUn]
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nU] dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nUn]
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nU,nU] dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nUn,nUn]
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nU] dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nUn]
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nU,nU] dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nUn,nUn]
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nU] dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nUn]
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nU] dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nUn]
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nU] dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nUn]
if(dLdWx0 != nullptr) if(dLdWx0 != nullptr)
*dLdWx += *dLdWx0; *dLdWx += *dLdWx0;
if(dLdWh0 != nullptr) if(dLdWh0 != nullptr)
*dLdWh += *dLdWh0; *dLdWh += *dLdWh0;
if(dLdb0 != nullptr) if(dLdb0 != nullptr)
*dLdb += *dLdb0; *dLdb += *dLdb0;
} }
@ -232,24 +255,24 @@ if(dLdb0 != nullptr)
// void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) { // void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
// NDArray<T>* x = inArrs[0]; // input [time, bS, iS] // NDArray<T>* x = inArrs[0]; // input [time, bS, iS]
// NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nU], that is at previous time step t-1 // NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nUn], that is at previous time step t-1
// NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nU] // NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nUn]
// NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nU, 3*nU] // NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nUn, 3*nUn]
// NDArray<T>* b = inArrs[4]; // biases, [3*nU] // NDArray<T>* b = inArrs[4]; // biases, [3*nUn]
// NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nU], that is epsilon_next // NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nUn], that is epsilon_next
// NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon // NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon
// NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nU] // NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nUn]
// NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nU] // NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nUn]
// NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nU, 3*nU] // NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nUn, 3*nUn]
// NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nU] // NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nUn]
// const Nd4jLong time = x->sizeAt(0); // const Nd4jLong time = x->sizeAt(0);
// const Nd4jLong bS = x->sizeAt(1); // const Nd4jLong bS = x->sizeAt(1);
// const Nd4jLong iS = x->sizeAt(2); // const Nd4jLong iS = x->sizeAt(2);
// const Nd4jLong nU = hi->sizeAt(1); // const Nd4jLong nUn = hi->sizeAt(1);
// NDArray<T> h(hi->ordering(), {time, bS, nU}); // feed forward output // NDArray<T> h(hi->ordering(), {time, bS, nUn}); // feed forward output
// // first step, time = 0, feed forward // // first step, time = 0, feed forward
// NDArray<T> x0 = (*x)({{0,1}, {}, {}}); // NDArray<T> x0 = (*x)({{0,1}, {}, {}});

View File

@ -85,11 +85,11 @@ namespace helpers {
PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold()) PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
for (int i = 0; i < n; i++) for (int i = 0; i < n; i++)
invertedMatrix->p(i, i, invertedMatrix->e<T>(i, i) / inputMatrix->e<T>(i, i)); invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold()) PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
for (int i = 0; i < n - 1; i++) for (int i = 0; i < n - 1; i++)
invertedMatrix->t<T>(i, i + 1) = invertedMatrix->t<T>(i, i+1) - (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i)); invertedMatrix->t<T>(i, i + 1) -= (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i));
// PRAGMA_OMP_PARALLEL_FOR_SIMD // PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int i = n - 2; i > - 1; i--) { for (int i = n - 2; i > - 1; i--) {
@ -124,25 +124,25 @@ namespace helpers {
for(int i = 0; i < rowNum; i++ ) { for(int i = 0; i < rowNum; i++ ) {
pivotValue = T(0.0); pivotValue = T(0.0);
pivot = -1; pivot = -1;
PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(pivot,pivotValue))
for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) { for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) {
if(nd4j::math::nd4j_abs(compoundMatrix.e<T>(rowCounter, i)) > pivotValue ) { if (nd4j::math::nd4j_abs(compoundMatrix.t<T>(rowCounter, i)) > pivotValue) {
pivotValue = nd4j::math::nd4j_abs(compoundMatrix.e<T>(rowCounter, i)); pivotValue = nd4j::math::nd4j_abs(compoundMatrix.t<T>(rowCounter, i));
pivot = rowCounter; pivot = rowCounter;
} }
} }
if( pivotValue != T(0.0) ) { if( pivotValue > T(0.00001)) {
swapRows(&compoundMatrix, pivot, i); swapRows(&compoundMatrix, pivot, i);
swapRows(&permutationMatrix, pivot, i); swapRows(&permutationMatrix, pivot, i);
if (pivot != i) if (pivot != i)
swapCount++; swapCount++;
for( int j = i + 1; j < rowNum; j++ ) { for( int j = i + 1; j < rowNum; j++ ) {
compoundMatrix.p(j, i, compoundMatrix.e<T>(j, i) / compoundMatrix.e<T>(i, i)); compoundMatrix.t<T>(j, i) /= compoundMatrix.t<T>(i, i);
PRAGMA_OMP_PARALLEL_FOR
for( int k = i + 1; k < rowNum; k++ ) { for( int k = i + 1; k < rowNum; k++ ) {
T arg = compoundMatrix.e<T>(j, i) * compoundMatrix.e<T>(i, k); compoundMatrix.t<T>(j, k) -= compoundMatrix.t<T>(j, i) * compoundMatrix.t<T>(i, k);
compoundMatrix.p(j, k, compoundMatrix.e<T>(j, k) - arg);
} }
} }
} }
@ -188,7 +188,7 @@ namespace helpers {
} }
template <typename T> template <typename T>
int log_abs_determinant_(NDArray* input, NDArray* output) { int logAbsDeterminant_(NDArray* input, NDArray* output) {
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
@ -206,14 +206,14 @@ template <typename T>
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
BUILD_SINGLE_TEMPLATE(template int log_abs_determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
int log_abs_determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return log_abs_determinant_, (input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES);
} }
template <typename T> template <typename T>
static int _inverse(NDArray* input, NDArray* output) { static int inverse_(NDArray* input, NDArray* output) {
auto n = input->sizeAt(-1); auto n = input->sizeAt(-1);
auto n2 = n * n; auto n2 = n * n;
@ -236,7 +236,7 @@ template <typename T>
T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0); T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
// FIXME: and how this is going to work on float16? // FIXME: and how this is going to work on float16?
if (nd4j::math::nd4j_abs<T>(det) < T(0.0000001)) { if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det); nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det);
matrix.printIndexedBuffer("Wrong matrix"); matrix.printIndexedBuffer("Wrong matrix");
return ND4J_STATUS_VALIDATION; return ND4J_STATUS_VALIDATION;
@ -244,12 +244,12 @@ template <typename T>
lowerMatrix.setIdentity(); // set up U to identity matrix lowerMatrix.setIdentity(); // set up U to identity matrix
for (int k = 1; k < n; k++) { // and then put all values under main diagonal on to it for (int k = 1; k < n; k++) { // and then put all values under main diagonal on to it
for (int j = 0; j < k; j++) for (int j = 0; j < k; j++)
lowerMatrix.p(k, j, compound.template e<T>(k, j)); lowerMatrix.template t<T>(k, j) = compound.template t<T>(k, j);
} }
upperMatrix.setIdentity(); // set up U to identity matrix upperMatrix.setIdentity(); // set up U to identity matrix
for (int k = 0; k < n; k++) { // and then put all values under main diagonal on to it for (int k = 0; k < n; k++) { // and then put all values under main diagonal on to it
for (int j = k; j < n; j++) for (int j = k; j < n; j++)
upperMatrix.p(k, j, compound.template e<T>(k, j)); upperMatrix.template t<T>(k, j) = compound.template e<T>(k, j);
} }
invertUpperMatrix(&upperMatrix, &matrix); invertUpperMatrix(&upperMatrix, &matrix);
@ -258,7 +258,7 @@ template <typename T>
nd4j::MmulHelper::mmul(&matrix, &upperMatrix, &compound, 1.0, 0.0); nd4j::MmulHelper::mmul(&matrix, &upperMatrix, &compound, 1.0, 0.0);
nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0); nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0);
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) { for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
output->p(k, matrix.template e<T>(row++)); output->t<T>(k) = matrix.template t<T>(row++);
} }
} }
@ -266,7 +266,7 @@ template <typename T>
} }
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return _inverse, (input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
} }
template <typename T> template <typename T>
@ -346,7 +346,7 @@ template <typename T>
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
} }
BUILD_SINGLE_TEMPLATE(template int cholesky_, (NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int cholesky_, (NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template int _inverse, (NDArray* input, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int inverse_, (NDArray* input, NDArray* output), FLOAT_TYPES);
template <typename T> template <typename T>
int logdetFunctor_(NDArray* input, NDArray* output) { int logdetFunctor_(NDArray* input, NDArray* output) {

View File

@ -22,125 +22,122 @@
#include <numeric> #include <numeric>
#include <helpers/ShapeUtils.h> #include <helpers/ShapeUtils.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
const int outRank = output.rankOf(); void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
const int indRank = indices.rankOf();
const int updRank = updates.rankOf();
const Nd4jLong indLen = indices.lengthOf();
if(outRank == 1) { const int outRank = output.rankOf();
const int indRank = indices.rankOf();
const int updRank = updates.rankOf();
const Nd4jLong indLen = indices.lengthOf();
if(outRank == 1) {
<<<<<<< HEAD
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) { for(Nd4jLong i = 0; i < indLen; ++i) {
for(Nd4jLong i = 0; i < indLen; ++i) {
Nd4jLong idx = indices.e<Nd4jLong>(i); Nd4jLong idx = indices.e<Nd4jLong>(i);
NDArray out = output({idx, idx+1}); NDArray out = output({idx, idx+1});
out.applyPairwiseTransform(op, updates.e(i), nullptr); out.applyPairwiseTransform(op, updates.e(i), nullptr);
} }
} }
else { // outRank > 1 else { // outRank > 1
int sizeOfDims = indRank; int sizeOfDims = indRank;
if(outRank == updRank && indices.isVector()) if(outRank == updRank && indices.isVector())
sizeOfDims = 1; sizeOfDims = 1;
std::vector<int> dimsToExcludeUpd(sizeOfDims); std::vector<int> dimsToExcludeUpd(sizeOfDims);
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug ! // PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) { for(Nd4jLong i = 0; i < indLen; ++i) {
NDArray outSubArr = output(indices.e<Nd4jLong>(i), std::vector<int>({0})); NDArray outSubArr = output(indices.e<Nd4jLong>(i), std::vector<int>({0}));
NDArray updSubArr = updates(i, dimsToExcludeUpd); NDArray updSubArr = updates(i, dimsToExcludeUpd);
outSubArr.applyPairwiseTransform(op, updSubArr, nullptr); outSubArr.applyPairwiseTransform(op, updSubArr, nullptr);
} }
} }
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) { void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
const Nd4jLong indLen = indices.lengthOf(); const Nd4jLong indLen = indices.lengthOf();
const int outRank = output.rankOf(); const int outRank = output.rankOf();
const int indRank = indices.rankOf(); const int indRank = indices.rankOf();
const Nd4jLong indLastDim = indices.sizeAt(-1); const Nd4jLong indLastDim = indices.sizeAt(-1);
if(outRank == 1) { if(outRank == 1) {
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
for(Nd4jLong i = 0; i < indLen; ++i) { for(Nd4jLong i = 0; i < indLen; ++i) {
Nd4jLong idx = indices.e<Nd4jLong>(i); Nd4jLong idx = indices.e<Nd4jLong>(i);
NDArray out = output({idx, idx+1}); NDArray out = output({idx, idx+1});
out.applyPairwiseTransform(op, updates.e(i), nullptr); out.applyPairwiseTransform(op, updates.e(i), nullptr);
} }
} }
else { else {
std::vector<int> dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1}); std::vector<int> dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1});
std::vector<int> dimsToExcludeUpd(indRank - 1); std::vector<int> dimsToExcludeUpd(indRank - 1);
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
std::vector<Nd4jLong> idxRangeOut(2*outRank, 0); std::vector<Nd4jLong> idxRangeOut(2*outRank, 0);
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut)) // PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut)) PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut))
for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) { for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) {
NDArray indSubArr = indices(i, dimsToExcludeInd); NDArray indSubArr = indices(i, dimsToExcludeInd);
for(Nd4jLong j = 0; j < indLastDim; ++j) { for(Nd4jLong j = 0; j < indLastDim; ++j) {
idxRangeOut[2*j] = indSubArr.e<Nd4jLong>(j); idxRangeOut[2*j] = indSubArr.e<Nd4jLong>(j);
idxRangeOut[2*j + 1] = idxRangeOut[2*j] + 1; idxRangeOut[2*j + 1] = idxRangeOut[2*j] + 1;
} }
NDArray outSubArr = output(idxRangeOut); NDArray outSubArr = output(idxRangeOut);
NDArray updSubArr = updates(i, dimsToExcludeUpd); NDArray updSubArr = updates(i, dimsToExcludeUpd);
outSubArr.applyPairwiseTransform(op, updSubArr, nullptr); outSubArr.applyPairwiseTransform(op, updSubArr, nullptr);
} }
} }
} }
void scatterForLoss(nd4j::LaunchContext *context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) {
void scatterForLoss(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& updates, NDArray& output, const bool calcGrad) { // shapes of indices and output must be the same
// requirements for arrays // shape of indices should be the same as updates shape with last dimension excluded
// shapes of updates and output must be the same // for example if updates is {a,b,c} then indices should be {a,b}
// shape of indices should be the same as updates shape with last dimension excluded
// for example if updates is {a,b,c} then indices should be {a,b} const Nd4jLong indicesLen = indices.lengthOf();
const Nd4jLong indicesLen = indices.lengthOf(); std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(updates.rankOf(), {-1});
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(updates.rankOf(), {-1}); if(!calcGrad) {
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided))
if(!calcGrad) { for(Nd4jLong i = 0; i < indicesLen; ++i) {
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided))
for(Nd4jLong i = 0; i < indicesLen; ++i) { auto subArr = updates(i, dimsToExclude);
output.p(i, subArr.e(indices.e<Nd4jLong>(i)));
auto subArr = updates(i, dimsToExclude); }
output.p(i, subArr.e(indices.e<Nd4jLong>(i))); } else {
} PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided))
} for(Nd4jLong i = 0; i < indicesLen; ++i) {
else {
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided)) auto subArr = updates(i, dimsToExclude);
for(Nd4jLong i = 0; i < indicesLen; ++i) { auto ind = indices.e<Nd4jLong>(i);
subArr.p(ind, subArr.e(ind) - 1.);
auto subArr = updates(i, dimsToExclude);
auto ind = indices.e<Nd4jLong>(i);
subArr.p(ind, subArr.e(ind) - 1.);
}
}
}
} }
} }
} }

View File

@ -22,6 +22,7 @@
#include<ops/declarable/helpers/sru.h> #include<ops/declarable/helpers/sru.h>
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
#include <MmulHelper.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -108,27 +109,27 @@ void sruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
template <typename T> template <typename T>
static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
// x input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features
// w 2d tensor of weights [2*inSize x 6*inSize] // w 2d tensor of weights [2*K x 6*K]
// b row of biases with twice length [1 x 4*inSize] // b row of biases with twice length [4*K]
// c0 2d tensor of initial state [bS x 2*inSize] at time t=0 // c0 2d tensor of initial state [bS x 2*K] at time t=0
// mask optional, 2d tensor of dropout mask [bS x 2*inSize] // mask optional, 2d tensor of dropout mask [bS x 2*K]
// ht [time x bS x 2*inSize] // ht [time x bS x 2*K]
// ct [time x bS x 2*inSize] // ct [time x bS x 2*K]
const Nd4jLong time = x->sizeAt(0); // time - number of time steps const Nd4jLong time = x->sizeAt(0); // time - number of time steps
const Nd4jLong bS = x->sizeAt(1); // bS - batch size const Nd4jLong bS = x->sizeAt(1); // bS - batch size
const Nd4jLong inSize = x->sizeAt(2) / 2; // inSize - number of features const Nd4jLong K = x->sizeAt(2) / 2; // K - number of features
// x = x * mask // x = x * mask
if(mask) if(mask)
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
// U = x * w // U = x * w
NDArray wi = mmul(*x, *w); // U [time x bS x 6*inSize] NDArray wi = mmul(*x, *w); // U [time x bS x 6*K]
const Nd4jLong d2 = 2*inSize; const Nd4jLong d2 = 2*K;
const Nd4jLong ncols = bS*d2; const Nd4jLong ncols = bS*d2;
const Nd4jLong ncolsWi = 3*ncols; const Nd4jLong ncolsWi = 3*ncols;
@ -140,42 +141,39 @@ static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray
T* pHt = ht->bufferAsT<T>(); T* pHt = ht->bufferAsT<T>();
T* pCt = ct->bufferAsT<T>(); T* pCt = ct->bufferAsT<T>();
Nd4jLong ncolsRev, ncolsWiRev; // for reverse direction PRAGMA_OMP_PARALLEL_FOR
T maskVal, cur, bF, bR, ft, rt, val;
T *pIVal(nullptr), *pWiVal(nullptr), *pHtVal(nullptr), *pCtVal(nullptr);
bool flip = false;
for (Nd4jLong col = 0; col < ncols; ++col) { for (Nd4jLong col = 0; col < ncols; ++col) {
const auto colNum = col % d2; const auto colNum = col % d2;
flip = colNum >= inSize; bool flip = colNum >= K;
maskVal = mask ? *(pMask + col) : T(1); T maskVal = mask ? *(pMask + col) : T(1);
cur = *(pInit + col); T cur = *(pInit + col);
bF = *(pBias + colNum); T bF = *(pBias + colNum);
bR = *(pBias + colNum + d2); T bR = *(pBias + colNum + d2);
pWiVal = pWi + 3*col; T* pWiVal = pWi + 3*col;
pIVal = pI + col; T* pIVal = pI + col;
pHtVal = pHt + col; T* pHtVal = pHt + col;
pCtVal = pCt + col; T* pCtVal = pCt + col;
if (flip) { if (flip) {
pIVal += (time-1)*ncols; const auto step = (time - 1) * ncols;
pWiVal += (time-1)*ncolsWi; pIVal += step;
pHtVal += (time-1)*ncols; pHtVal += step;
pCtVal += (time-1)*ncols; pCtVal += step;
pWiVal += (time - 1) * ncolsWi;
} }
ncolsRev = flip ? -ncols : ncols; auto ncolsRev = flip ? -ncols : ncols;
ncolsWiRev = flip ? -ncolsWi : ncolsWi; auto ncolsWiRev = flip ? -ncolsWi : ncolsWi;
for (Nd4jLong t = 0; t < time; ++t) { for (Nd4jLong t = 0; t < time; ++t) {
// evaluate sigmoids // evaluate sigmoids
ft = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(*(pWiVal + 1) + bF))); T ft = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(pWiVal[1] + bF)));
rt = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(*(pWiVal + 2) + bR))); T rt = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(pWiVal[2] + bR)));
cur = (cur - *pWiVal)*ft + *pWiVal; cur = (cur - *pWiVal)*ft + *pWiVal;
*pCtVal = cur; *pCtVal = cur;
val = nd4j::math::nd4j_tanh<T, T>(cur); T val = nd4j::math::nd4j_tanh<T, T>(cur);
*pHtVal = (val*maskVal - *pIVal)*rt + *pIVal; *pHtVal = (val*maskVal - *pIVal)*rt + *pIVal;
pIVal += ncolsRev; pIVal += ncolsRev;
@ -191,34 +189,34 @@ template <typename T>
static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask, static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask,
NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
// x input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features // x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features
// w 2d tensor of weights [2*inSize x 6*inSize] // w 2d tensor of weights [2*K x 6*K]
// b row of biases with twice length [1 x 4*inSize] // b row of biases with twice length 4*K]
// c0 2d tensor of initial state [bS x 2*inSize] at time t=0 // c0 2d tensor of initial state [bS x 2*K] at time t=0
// ct [time x bS x 2*inSize] // ct [time x bS x 2*K]
// inGradC0 [bS x 2*inSize] // inGradC0 [bS x 2*K]
// inGradHt [time x bS x 2*inSize] // inGradHt [time x bS x 2*K]
// mask optional, 2d tensor of dropout mask [bS x 2*inSize] // mask optional, 2d tensor of dropout mask [bS x 2*K]
// gradI [time x bS x 2*inSize] // gradI [time x bS x 2*K]
// gradW [time x 2*inSize x 6*inSize] // gradW [time x 2*K x 6*K]
// gradB [1 x 4*inSize] // gradB [4*K]
// gradC0 [bS x 2*inSize] // gradC0 [bS x 2*K]
const Nd4jLong time = x->sizeAt(0); // time - number of time steps const Nd4jLong time = x->sizeAt(0); // time - number of time steps
const Nd4jLong bS = x->sizeAt(1); const Nd4jLong bS = x->sizeAt(1);
const Nd4jLong inSize = x->sizeAt(2) / 2; const Nd4jLong K = x->sizeAt(2) / 2;
// x = x * mask // x = x * mask
if(mask) if(mask)
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
// U = x * w // U = x * w
NDArray wi = mmul(*x, *w); // [time x bS x 2*inSize] * [2*inSize x 6*inSize] = [time x bS x 6*inSize] NDArray wi = mmul(*x, *w); // [time x bS x 2*K] * [2*K x 6*K] = [time x bS x 6*K]
NDArray gradBias(x->ordering(), {bS, 4*inSize}, x->dataType(), x->getContext()); NDArray gradBias(x->ordering(), {bS, 4*K}, x->dataType(), x->getContext());
NDArray gradWi (x->ordering(), {time, bS, 6*inSize}, x->dataType(), x->getContext()); NDArray gradWi (x->ordering(), {time, bS, 6*K}, x->dataType(), x->getContext());
const Nd4jLong d2 = 2*inSize; const Nd4jLong d2 = 2*K;
const Nd4jLong ncols = bS*d2; const Nd4jLong ncols = bS*d2;
const Nd4jLong ncolsWi = 3*ncols; const Nd4jLong ncolsWi = 3*ncols;
T* pInput = x->bufferAsT<T>(); T* pInput = x->bufferAsT<T>();
@ -234,59 +232,61 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr
T* pGradBias = gradBias.bufferAsT<T>(); T* pGradBias = gradBias.bufferAsT<T>();
T* pGradInit = gradC0->bufferAsT<T>(); T* pGradInit = gradC0->bufferAsT<T>();
Nd4jLong ncolsRev, ncolsWiRev; // for reverse direction PRAGMA_OMP_PARALLEL_FOR
T gbF, gbR, cur, maskVal, bF, bR, ft, rt, val, prevVal, gft, grt, gradSateVal;
bool flip = false;
T *pInputVal(nullptr), *pWiVal(nullptr), *pStateVal(nullptr), *pInGradHtVal(nullptr), *pGradWiVal(nullptr), *pGradInputVal(nullptr);
for (Nd4jLong col = 0; col < ncols; ++col) { for (Nd4jLong col = 0; col < ncols; ++col) {
gbF = gbR = (T)0.; T gbF = 0.f;
T gbR = 0.f;
const auto colNum = col % d2; const auto colNum = col % d2;
flip = colNum >= inSize; const bool flip = colNum >= K;
maskVal = mask ? *(pMask + col) : T(1.); T maskVal = mask ? *(pMask + col) : T(1.);
cur = *(pInGradCt + col); T cur = *(pInGradCt + col);
bF = *(pBias + colNum); T bF = *(pBias + colNum);
bR = *(pBias + colNum + d2); T bR = *(pBias + colNum + d2);
pWiVal = pWi + 3*col; T* pWiVal = pWi + 3*col;
pInputVal = pInput + col; T* pInputVal = pInput + col;
pStateVal = pState + col; T* pStateVal = pState + col;
pInGradHtVal = pInGradHt + col; T* pInGradHtVal = pInGradHt + col;
pGradWiVal = pGradWi + 3*col; T* pGradWiVal = pGradWi + 3*col;
pGradInputVal = pGradInput + col; T* pGradInputVal = pGradInput + col;
if (!flip) { if (!flip) {
pInputVal += (time-1)*ncols; const auto stepI = (time - 1) * ncols;
pWiVal += (time-1)*ncolsWi; const auto stepW = (time - 1) * ncolsWi;
pStateVal += (time-1)*ncols; pInputVal += stepI;
pInGradHtVal += (time-1)*ncols; pStateVal += stepI;
pGradWiVal += (time-1)*ncolsWi; pInGradHtVal += stepI;
pGradInputVal += (time-1)*ncols; pGradInputVal += stepI;
pWiVal += stepW;
pGradWiVal += stepW;
} }
ncolsRev = flip ? -ncols : ncols;
ncolsWiRev = flip ? -ncolsWi : ncolsWi; Nd4jLong ncolsRev = flip ? -ncols : ncols;
Nd4jLong ncolsWiRev = flip ? -ncolsWi : ncolsWi;
for (Nd4jLong t = 0; t < time; ++t) { for (Nd4jLong t = 0; t < time; ++t) {
// evaluate sigmoids // evaluate sigmoids
ft = ((T)1.)/((T)1. + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 1) + bF))); T ft = ((T)1.)/((T)1. + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 1) + bF)));
rt = ((T)1.)/((T)1. + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 2) + bR))); T rt = ((T)1.)/((T)1. + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 2) + bR)));
val = nd4j::math::nd4j_tanh<T,T>(*pStateVal); T val = nd4j::math::nd4j_tanh<T,T>(*pStateVal);
prevVal = (t < time-1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col)); T prevVal = (t < time-1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col));
// grad wrt input // grad wrt input
*pGradInputVal = *pInGradHtVal - (*pInGradHtVal)*rt ; *pGradInputVal = *pInGradHtVal - (*pInGradHtVal)*rt ;
// grad wrt rt, wiR and bR // grad wrt rt, wiR and bR
grt = (*pInGradHtVal) * (val*maskVal - *pInputVal) * (rt - rt*rt); T grt = (*pInGradHtVal) * (val*maskVal - *pInputVal) * (rt - rt*rt);
*(pGradWiVal + 2) = grt; *(pGradWiVal + 2) = grt;
gbR += grt; gbR += grt;
// grad wrt state // grad wrt state
gradSateVal = (*pInGradHtVal) * maskVal * (rt - rt*val*val) + cur; T gradSateVal = (*pInGradHtVal) * maskVal * (rt - rt*val*val) + cur;
// grad wrt wi0 // grad wrt wi0
*pGradWiVal = gradSateVal - gradSateVal*ft; *pGradWiVal = gradSateVal - gradSateVal*ft;
// grad wrt ft, wi1, and bF // grad wrt ft, wi1, and bF
gft = gradSateVal * (prevVal - *pWiVal) * (ft - ft*ft); T gft = gradSateVal * (prevVal - *pWiVal) * (ft - ft*ft);
*(pGradWiVal + 1) = gft; *(pGradWiVal + 1) = gft;
gbF += gft; gbF += gft;
// grad wrt c_previous // grad wrt c_previous
cur = gradSateVal * ft; cur = gradSateVal * ft;
pInputVal -= ncolsRev; pInputVal -= ncolsRev;
pWiVal -= ncolsWiRev; pWiVal -= ncolsWiRev;
pStateVal -= ncolsRev; pStateVal -= ncolsRev;
@ -300,11 +300,11 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr
} }
// gradB // gradB
gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}, false, true); // [1 x 4*inSize] gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}); // [4*K]
// gradW // gradW
x->permutei({0, 2, 1}); // [time x bS x 2*inSize] -> [time x 2*inSize x bS] x->permutei({0, 2, 1}); // [time x bS x 2*K] -> [time x 2*K x bS]
*gradW = mmul(*x, gradWi); // [time x 2*inSize x bS ] * [time x bS x 6*inSize] = [time x 2*inSize x 6*inSize] MmulHelper::mmul(x, &gradWi, gradW, 1., 0.); // [time x 2*K x bS ] * [time x bS x 6*K] = [time x 2*K x 6*K]
} }

View File

@ -43,8 +43,8 @@ static void triuBP_(nd4j::LaunchContext * context, const NDArray& input, const N
PRAGMA_OMP_PARALLEL_FOR_IF(dLen > Environment::getInstance()->elementwiseThreshold()) PRAGMA_OMP_PARALLEL_FOR_IF(dLen > Environment::getInstance()->elementwiseThreshold())
for(int i = 0; i < dLen; ++i) { for(int i = 0; i < dLen; ++i) {
if(dOdI.e<T>(i) != (T)0.f) if(dOdI.t<T>(i) != static_cast<T>(0.f))
dOdI.p(i, T(1.f)); dOdI.t<T>(i) = static_cast<T>(1.f);
} }
// FIXME: !!! // FIXME: !!!

View File

@ -32,30 +32,48 @@ namespace helpers {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename X, typename Y> template<typename X, typename Y>
__global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo, __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo, const void *vy, const Nd4jLong *yShapeInfo,
void *vz) { void *vz) {
const auto x = reinterpret_cast<const X*>(vx); const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const Y*>(vy); const auto y = reinterpret_cast<const Y*>(vy);
auto z = reinterpret_cast<X*>(vz); auto z = reinterpret_cast<X*>(vz);
__shared__ Nd4jLong len; __shared__ Nd4jLong xzLen, totalThreads, *sharedMem;
__shared__ int xzRank, yRank;
if (threadIdx.x == 0) if (threadIdx.x == 0) {
len = shape::length(xShapeInfo); extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xzLen = shape::length(xShapeInfo);
totalThreads = gridDim.x * blockDim.x;
xzRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo);
}
__syncthreads(); __syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto totalThreads = gridDim.x * blockDim.x; Nd4jLong* coords = sharedMem + threadIdx.x * xzRank;
for (int i = tid; i < len; i += totalThreads) { for (int i = tid; i < xzLen; i += totalThreads) {
const auto xzOffset = shape::getIndexOffset(i, xShapeInfo, len); shape::index2coords(xzRank, xShapeInfo + 1, i, xzLen, coords);
const auto xVal = x[xzOffset];
if(xVal < 0) const auto xzOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xzRank + 1, coords, xzRank);
z[xzOffset] = xVal * y[shape::subArrayOffset(i, xShapeInfo, yShapeInfo)];
const auto xVal = x[xzOffset];
if(xVal < 0) {
for (uint j = 0; j < yRank; ++j)
if(yShapeInfo[j + 1] == 1)
coords[j + 1] = 0;
z[xzOffset] = xVal * y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + yRank + 1, coords + 1, yRank)];
}
else else
z[xzOffset] = xVal; z[xzOffset] = xVal;
} }
@ -63,28 +81,121 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename X, typename Y> template<typename X, typename Y>
linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) { linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) {
preluCuda<X, Y><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz); preluCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz);
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) { void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) {
if(!input.isActualOnDeviceSide()) input.syncToDevice();
if(!alpha.isActualOnDeviceSide()) alpha.syncToDevice(); PointersManager manager(context, "prelu");
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
const auto xType = input.dataType(); const auto xType = input.dataType();
const auto yType = alpha.dataType(); const auto yType = alpha.dataType();
int threadsPerBlock = MAX_NUM_THREADS;
int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
BUILD_DOUBLE_SELECTOR(xType, yType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), output.getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES); NDArray::prepareSpecialUse({&output}, {&input, &alpha});
BUILD_DOUBLE_SELECTOR(xType, yType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), output.getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES);
NDArray::registerSpecialUse({&output}, {&input, &alpha});
input.tickReadHost(); manager.synchronize();
alpha.tickReadHost();
output.tickWriteDevice();
} }
///////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeInfo,
const void *vAlpha, const Nd4jLong *alphaShapeInfo,
const void *vdLdO, const Nd4jLong *dLdOShapeInfo,
void *vdLdI, const Nd4jLong *dLdIShapeInfo,
void *vdLdA, const Nd4jLong *dLdAShapeInfo) {
const auto in = reinterpret_cast<const X*>(vIn);
const auto alpha = reinterpret_cast<const Y*>(vAlpha);
const auto dLdO = reinterpret_cast<const Y*>(vdLdO);
auto dLdI = reinterpret_cast<Y*>(vdLdI);
auto dLdA = reinterpret_cast<Y*>(vdLdA);
__shared__ Nd4jLong inLen, totalThreads, *sharedMem;
__shared__ int inRank, alphaRank;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
inLen = shape::length(inShapeInfo);
totalThreads = gridDim.x * blockDim.x;
inRank = shape::rank(inShapeInfo);
alphaRank = shape::rank(alphaShapeInfo);
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong* coords = sharedMem + threadIdx.x * inRank;
for (int i = tid; i < inLen; i += totalThreads) {
shape::index2coords(inRank, inShapeInfo + 1, i, inLen, coords);
const auto inOffset = shape::getOffset(0, inShapeInfo + 1, inShapeInfo + inRank + 1, coords, inRank);
const auto dLdOOffset = shape::getOffset(0, dLdOShapeInfo + 1, dLdOShapeInfo + inRank + 1, coords, inRank);
const auto dLdIOffset = shape::getOffset(0, dLdIShapeInfo + 1, dLdIShapeInfo + inRank + 1, coords, inRank);
const auto xVal = in[inOffset];
const auto grO = dLdO[dLdOOffset];
if(xVal < 0) {
for (uint j = 0; j < alphaRank; ++j)
if(alphaShapeInfo[j + 1] == 1)
coords[j + 1] = 0;
const auto alphaOffset = shape::getOffset(0, alphaShapeInfo + 1, alphaShapeInfo + alphaRank + 1, coords + 1, alphaRank);
const auto dLdAOffset = shape::getOffset(0, dLdAShapeInfo + 1, dLdAShapeInfo + alphaRank + 1, coords + 1, alphaRank);
dLdI[dLdIOffset] = grO * alpha[alphaOffset];
nd4j::math::atomics::nd4j_atomicAdd<Y>(&dLdA[dLdAOffset], static_cast<Y>(grO * xVal));
}
else
dLdI[dLdIOffset] = grO;
}
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo) {
preluBPCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vIn, inShapeInfo, vAlpha, alphaShapeInfo, vdLdO, dLdOShapeInfo, vdLdI, dLdIShapeInfo, vdLdA, dLdAShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
dLdA.nullify();
PointersManager manager(context, "preluBP");
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
const auto xType = input.dataType();
const auto zType = alpha.dataType();
NDArray::prepareSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO});
BUILD_DOUBLE_SELECTOR(xType, zType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), dLdO.getSpecialBuffer(), dLdO.getSpecialShapeInfo(), dLdI.getSpecialBuffer(), dLdI.getSpecialShapeInfo(), dLdA.getSpecialBuffer(), dLdA.getSpecialShapeInfo()), LIBND4J_TYPES, FLOAT_TYPES);
NDArray::registerSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO});
manager.synchronize();
}
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) { __global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) {
@ -439,88 +550,6 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr
output.tickWriteDevice(); output.tickWriteDevice();
} }
///////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeInfo,
const void *vAlpha, const Nd4jLong *alphaShapeInfo,
const void *vdLdO, const Nd4jLong *dLdOShapeInfo,
void *vdLdI, const Nd4jLong *dLdIShapeInfo,
void *vdLdA, const Nd4jLong *dLdAShapeInfo) {
const auto in = reinterpret_cast<const X*>(vIn);
const auto alpha = reinterpret_cast<const Y*>(vAlpha);
const auto dLdO = reinterpret_cast<const Y*>(vdLdO);
auto dLdI = reinterpret_cast<Y*>(vdLdI);
auto dLdA = reinterpret_cast<Y*>(vdLdA);
__shared__ Nd4jLong alphaLen;
if (threadIdx.x == 0)
alphaLen = shape::length(alphaShapeInfo);
__syncthreads();
const auto i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= alphaLen) return;
Nd4jLong inputIdxs[MAX_RANK*2];
int numIdxs = shape::outerArrayOffsets(inputIdxs, i, inShapeInfo, alphaShapeInfo);
Nd4jLong dLdOIdxs[MAX_RANK*2];
shape::outerArrayOffsets(dLdOIdxs, i, dLdOShapeInfo, alphaShapeInfo);
Nd4jLong dLdIIdxs[MAX_RANK*2];
shape::outerArrayOffsets(dLdIIdxs, i, dLdIShapeInfo, alphaShapeInfo);
const auto alphaOffset = shape::getIndexOffset(i, alphaShapeInfo, alphaLen);
const auto dLdAOffset = shape::getIndexOffset(i, dLdAShapeInfo, alphaLen);
for(Nd4jLong j = 0; j < numIdxs; ++j) {
const auto inInd = inputIdxs[j];
const auto dLdOInd = dLdOIdxs[j];
const auto dLdIInd = dLdIIdxs[j];
if(in[inInd] < 0) {
dLdI[dLdIInd] = dLdO[dLdOInd] * alpha[alphaOffset];
auto prevVal = dLdA[dLdAOffset];
prevVal = prevVal + dLdO[dLdOInd] * in[inInd];
dLdA[dLdAOffset] = prevVal;
}
else
dLdI[dLdIInd] = dLdO[dLdOInd];
}
}
template<typename X, typename Y>
__host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo) {
preluBPCuda<X, Y><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(vIn, inShapeInfo, vAlpha, alphaShapeInfo, vdLdO, dLdOShapeInfo, vdLdI, dLdIShapeInfo, vdLdA, dLdAShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
if(!input.isActualOnDeviceSide()) input.syncToDevice();
if(!alpha.isActualOnDeviceSide()) alpha.syncToDevice();
if(!dLdO.isActualOnDeviceSide()) dLdO.syncToDevice();
const auto xType = input.dataType();
const auto zType = dLdO.dataType();
int threadsPerBlock = MAX_NUM_THREADS;
int blocksPerGrid = (alpha.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
BUILD_DOUBLE_SELECTOR(xType, zType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), dLdO.getSpecialBuffer(), dLdO.getSpecialShapeInfo(), dLdI.getSpecialBuffer(), dLdI.getSpecialShapeInfo(), dLdA.getSpecialBuffer(), dLdA.getSpecialShapeInfo()), LIBND4J_TYPES, FLOAT_TYPES);
input.tickReadHost();
alpha.tickReadHost();
dLdO.tickReadHost();
dLdI.tickWriteDevice();
dLdA.tickWriteDevice();
}
template <typename T> template <typename T>
linkage void thresholdRelu_(NDArray const& input, double threshold, NDArray& output) { linkage void thresholdRelu_(NDArray const& input, double threshold, NDArray& output) {
@ -545,8 +574,8 @@ __host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int thr
BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, (NDArray* input, double threshold, NDArray* dLdO, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, (NDArray* input, double threshold, NDArray* dLdO, NDArray* output), FLOAT_TYPES);
BUILD_DOUBLE_TEMPLATE(template void preluCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz), LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_TEMPLATE(template void preluCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_DOUBLE_TEMPLATE(template void preluBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo), LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_TEMPLATE(template void preluBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo), LIBND4J_TYPES, FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template void softMaxForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void softMaxForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template void softMaxDerivForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void softMaxDerivForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES);

View File

@ -803,8 +803,12 @@ __global__ static void pooling3dCuda(const void* vx, const Nd4jLong* xShapeInfo,
for (coords[4] = wstart; coords[4] < wend; coords[4] += dW) for (coords[4] = wstart; coords[4] < wend; coords[4] += dW)
sum += x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)]; sum += x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)];
if (extraParam0 == 0) //Exclude padding if (extraParam0 == 0) { //Exclude padding
sum /= nd4j::math::nd4j_ceil<double,T>(static_cast<double>(dend - dstart) / static_cast<double>(dD)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(hend - hstart) / static_cast<double>(dH)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(wend - wstart) / static_cast<double>(dW)); //Accounts for dilation uint a = (dend - dstart) / dD + ((dend - dstart) % dD == 0 ? 0 : 1);
uint b = (hend - hstart) / dH + ((hend - hstart) % dH == 0 ? 0 : 1);
uint c = (wend - wstart) / dW + ((wend - wstart) % dW == 0 ? 0 : 1);
sum /= static_cast<T>(a * b * c); // /= nd4j::math::nd4j_ceil<double,T>(static_cast<double>(dend - dstart) / static_cast<double>(dD)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(hend - hstart) / static_cast<double>(dH)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(wend - wstart) / static_cast<double>(dW)); //Accounts for dilation
}
else if (extraParam0 == 1) //Include padding else if (extraParam0 == 1) //Include padding
sum /= kProd; sum /= kProd;

View File

@ -15,26 +15,123 @@
******************************************************************************/ ******************************************************************************/
// //
// @author raver119@gmail.com // @author Yurii Shyrma (iuriish@yahoo.com)
// //
#include <ops/declarable/helpers/dilation2d.h> #include <ops/declarable/helpers/dilation2d.h>
#include <array/DataTypeUtils.h> #include <array/DataTypeUtils.h>
#include <PointersManager.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename X, typename Y>
static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left) {
}; //////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
__global__ static void dilation2dCuda(const void* vx, const Nd4jLong* xShapeInfo,
const void* vy, const Nd4jLong* yShapeInfo,
void* vz, const Nd4jLong* zShapeInfo,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW) {
void dilation2d(nd4j::LaunchContext * context, NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left) { // x [bS, iH, iW, iC]
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, stride_rows, stride_cols, rate_rows, rate_cols, pad_top, pad_left), LIBND4J_TYPES, FLOAT_TYPES); // y [kH, kW, iC]
// z [bS, oH, oW, iC]
const X* x = reinterpret_cast<const X*>(vx);
const X* y = reinterpret_cast<const X*>(vy);
Z* z = reinterpret_cast<Z*>(vz);
__shared__ int xzRank, yRank;
__shared__ uint iH, iW, kH, kW;
__shared__ Nd4jLong *sharedMem, zLen;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
zLen = shape::length(zShapeInfo);
xzRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo);
iH = xShapeInfo[2];
iW = xShapeInfo[3];
kH = yShapeInfo[1];
kW = yShapeInfo[2];
} }
BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left), LIBND4J_TYPES, FLOAT_TYPES); __syncthreads();
const auto zInd = threadIdx.x + blockIdx.x * blockDim.x;
if(zInd >= zLen)
return;
auto xzCoords = sharedMem + threadIdx.x * (xzRank + yRank);
auto yCoords = xzCoords + xzRank;
shape::index2coords(xzRank, zShapeInfo + 1, zInd, zLen, xzCoords);
const auto zOffset = shape::getOffset(zShapeInfo, xzCoords);
yCoords[2] = xzCoords[3]; // iC coordinate is same for x, y and z
const auto oh = xzCoords[1];
const auto ow = xzCoords[2];
X max = -DataTypeUtils::max<X>();
for (yCoords[0] = 0; yCoords[0] < kH; ++yCoords[0]) {
xzCoords[1] = oh * sH - pH + yCoords[0] * dH;
if (xzCoords[1] < 0 || xzCoords[1] >= iH) continue;
for (yCoords[1] = 0; yCoords[1] < kW; ++yCoords[1]) {
xzCoords[2] = ow * sW - pW + yCoords[1] * dW;
if(xzCoords[2] < 0 || xzCoords[2] >= iW) continue;
const X val = x[shape::getOffset(xShapeInfo, xzCoords)] + y[shape::getOffset(yShapeInfo, yCoords)];
if (val > max)
max = val;
}
}
z[zOffset] = static_cast<Z>(max);
}
//////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
static void dilation2dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo,
const void* vy, const Nd4jLong* yShapeInfo,
void* vz, const Nd4jLong* zShapeInfo,
const int sH, const int sW,
const int pH, const int pW,
const int dH, const int dW) {
dilation2dCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, sH, sW, pH, pW, dH, dW);
}
BUILD_DOUBLE_TEMPLATE(template void dilation2dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES);
void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
PointersManager manager(context, "dilation2d");
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = (weights->rankOf() + output->rankOf()) * sizeof(Nd4jLong) * threadsPerBlock + 128;
NDArray::prepareSpecialUse({output}, {input, weights});
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), weights->getSpecialBuffer(), weights->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES);
NDArray::registerSpecialUse({output}, {input, weights});
manager.synchronize();
}
} }
} }
} }

View File

@ -107,6 +107,102 @@ void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0, void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) { const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
// x input [bS, iS]
// h0 previous cell output [bS, nU], that is at previous time step t-1
// Wx input-to-hidden weights, [iS, 3*nU]
// Wh hidden-to-hidden weights, [nU, 3*nU]
// b biases, [3*nU]
// dLdh gradient wrt output, [bS,nU], that is epsilon_next
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nU]
// dLdWh0 gradient wrt Wh at previous time step, [nU, 3*nU]
// dLdb0 gradient wrt b at previous time step, [3*nU]
// dLdx gradient wrt x, [bS, iS], that is epsilon
// dLdh0 gradient wrt h0, [bS, nU]
// dLdWx gradient wrt Wx, [iS, 3*nU]
// dLdWh gradient wrt Wh, [nU, 3*nU]
// dLdb gradient wrt b at previous time step, [3*nU]
// h is current cell output [bS, nU], that is at current time step t
const int nU = h0->sizeAt(1);
// ***** feed forward step ***** //
// gates = sigmoid(x*Wx + h0*Wh + b)
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nU})) + mmul(*h0, (*Wh)({0,0, 0,2*nU})) + (*b)({0,2*nU})); // [bS, 2*nU] + [bS, 2*nU] + [1, 2*nU] = [bS, 2*nU]
// reset gate
auto r = gates({0,0, 0, nU}); // [bS, nU]
// update gate
auto u = gates({0,0, nU, 2*nU}); // [bS, nU]
// ◦ means element-wise product or so called Hadamard product
// n = tanh(x*Wx + (r◦h0)*Wh + b)
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nU,3*nU})) + mmul((*h0)*r, (*Wh)({0,0, 2*nU,3*nU})) + (*b)({2*nU,3*nU})); // [bS, nU]
// ***** back prop step ***** //
auto Wxr = (*Wx)({0,0, 0, nU});
auto Wxu = (*Wx)({0,0, nU, 2*nU});
auto Wxn = (*Wx)({0,0, 2*nU,3*nU});
auto Whr = (*Wh)({0,0, 0, nU});
auto Whu = (*Wh)({0,0, nU, 2*nU});
auto Whn = (*Wh)({0,0, 2*nU,3*nU});
auto WxrT = Wxr.transpose();
auto WxuT = Wxu.transpose();
auto WxnT = Wxn.transpose();
auto WhrT = Whr.transpose();
auto WhuT = Whu.transpose();
auto WhnT = Whn.transpose();
auto xT = x->transpose();
auto h0T = h0->transpose();
auto dLdWxr = (*dLdWx)({0,0, 0, nU});
auto dLdWxu = (*dLdWx)({0,0, nU, 2*nU});
auto dLdWxn = (*dLdWx)({0,0, 2*nU,3*nU});
auto dLdWhr = (*dLdWh)({0,0, 0, nU});
auto dLdWhu = (*dLdWh)({0,0, nU, 2*nU});
auto dLdWhn = (*dLdWh)({0,0, 2*nU,3*nU});
auto dLdbr = (*dLdb)({0, nU});
auto dLdbu = (*dLdb)({nU, 2*nU});
auto dLdbn = (*dLdb)({2*nU,3*nU});
auto dhdu = *h0 - n; // [bS, nU]
auto dhdn = 1.f - u; // [bS, nU]
auto dSigdu = u * (1.f - u); // [bS, nU]
auto dSigdr = r * (1.f - r); // [bS, nU]
auto dActdn = 1.f - n * n; // [bS, nU]
auto dndr = mmul(dActdn * (*h0), WhnT);
auto drdh0 = mmul(dSigdr, WhrT);
auto dLdn = (*dLdh) * dhdn;
auto dLdu = (*dLdh) * dhdu;
auto dLdr = dLdn * dndr;
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nU]
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nU]
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nU,nU]
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nU]
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nU,nU]
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nU]
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nU,nU]
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nU]
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nU]
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nU]
if(dLdWx0 != nullptr)
*dLdWx += *dLdWx0;
if(dLdWh0 != nullptr)
*dLdWh += *dLdWh0;
if(dLdb0 != nullptr)
*dLdb += *dLdb0;
} }

View File

@ -23,13 +23,17 @@
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
#include <Status.h> #include <Status.h>
#include <ConstantTadHelper.h> #include <ConstantTadHelper.h>
#include <ShapeUtils.h>
#include <cusolverDn.h>
#include <cuda_exception.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T> template <typename T>
static __device__ void _swapRows(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) { static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
if (theFirst != theSecond) { if (theFirst != theSecond) {
auto start = threadIdx.x + blockIdx.x * blockDim.x; auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x; auto step = blockDim.x * gridDim.x;
@ -46,32 +50,180 @@ namespace helpers {
} }
} }
} }
// BUILD_SINGLE_TEMPLATE(template void _swapRows, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES); // BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
// //
// void swapRows(NDArray* matrix, int theFirst, int theSecond) { // void swapRows(NDArray* matrix, int theFirst, int theSecond) {
// BUILD_SINGLE_SELECTOR(matrix->dataType(), _swapRows, (matrix, theFirst, theSecond), FLOAT_TYPES); // BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES);
// } // }
template <typename T> template <typename T>
static void _invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { static __global__ void invertKernelLow(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
__shared__ T* inverted;
__shared__ T* input;
if (threadIdx.x == 0) {
inverted = reinterpret_cast<T*>(invertedBuf);
input = reinterpret_cast<T*>(inputBuf);
}
__syncthreads();
auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x;
for (int i = start + 1; i < n; i += step) {
Nd4jLong pos[] = {i, i - 1};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
inverted[zIndex] = -input[xIndex];
}
} }
BUILD_SINGLE_TEMPLATE(template void _invertLowerMatrix, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); template <typename T>
static __global__ void upvertKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
__shared__ T* inverted;
__shared__ T* input;
if (threadIdx.x == 0) {
inverted = reinterpret_cast<T*>(invertedBuf);
input = reinterpret_cast<T*>(inputBuf);
}
__syncthreads();
auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x;
for (int i = start + 1; i < n; i += step) {
Nd4jLong pos[] = {i, i};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
inverted[zIndex] /= input[xIndex];
}
}
template <typename T>
static __global__ void upvertKernelUp(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
__shared__ T* inverted;
__shared__ T* input;
if (threadIdx.x == 0) {
inverted = reinterpret_cast<T*>(invertedBuf);
input = reinterpret_cast<T*>(inputBuf);
}
__syncthreads();
auto start = threadIdx.x + blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x;
for (int i = start + 1; i < n - 1; i += step) {
Nd4jLong pos[] = {i, i + 1};
Nd4jLong posY[] = {i, i};
Nd4jLong posX[] = {i + 1, i + 1};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
inverted[zIndex] -= input[xIndex] * inverted[iIndex] / input[yIndex];
//inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i)
}
}
template <typename T>
static __global__ void invertLowKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
__shared__ T* inverted;
__shared__ T* input;
if (threadIdx.x == 0) {
inverted = reinterpret_cast<T*>(invertedBuf);
input = reinterpret_cast<T*>(inputBuf);
}
__syncthreads();
// auto start = threadIdx.x + blockIdx.x * blockDim.x;
// auto step = blockDim.x * gridDim.x;
for (int i = blockIdx.x + 2; i < n; i += gridDim.x) {
for (int j = i - 2; j > -1; --j)
for (int k = threadIdx.x; k < i; k+= blockDim.x) {
Nd4jLong posZ[] = {i, j};
Nd4jLong posX[] = {k, j};
Nd4jLong posY[] = {i, k};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
inverted[zIndex] -= inverted[yIndex] * input[xIndex];
}
}
}
template <typename T>
static __global__ void invertUpKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
__shared__ T* inverted;
__shared__ T* input;
if (threadIdx.x == 0) {
inverted = reinterpret_cast<T*>(invertedBuf);
input = reinterpret_cast<T*>(inputBuf);
}
__syncthreads();
// auto start = threadIdx.x + blockIdx.x * blockDim.x;
// auto step = blockDim.x * gridDim.x;
for (int i = n - blockIdx.x - 2; i >= 0; i -= gridDim.x) {
for (int j = i + 2; j < n; j++)
for (int k = i + threadIdx.x; k < n; k+= blockDim.x) {
Nd4jLong posZ[] = {i, j};
Nd4jLong posY[] = {k, j};
Nd4jLong posX[] = {i, k};
Nd4jLong posD[] = {i, i};
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
inverted[zIndex] -= inverted[yIndex] * input[xIndex] / input[dIndex];
}
}
}
template <typename T>
static void invertLowerMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
int n = inputMatrix->rows();
invertedMatrix->setIdentity();
if (inputMatrix->isIdentityMatrix()) return;
LaunchContext* context = inputMatrix->getContext();
auto stream = context->getCudaStream();
invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
}
BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), _invertLowerMatrix, (inputMatrix, invertedMatrix), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES);
} }
template <typename T> template <typename T>
static void _invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
int n = inputMatrix->rows();
invertedMatrix->setIdentity();
auto stream = inputMatrix->getContext()->getCudaStream();
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
return;
}
upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
} }
BUILD_SINGLE_TEMPLATE(template void _invertUpperMatrix, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) { void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), _invertUpperMatrix, (inputMatrix, invertedMatrix), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES);
} }
template <typename T> template <typename T>
@ -91,8 +243,8 @@ namespace helpers {
} }
if( pivotValue != T(0.0) ) { if( pivotValue != T(0.0) ) {
_swapRows<T>(compound, compoundShape, pivot, i, rowNum); swapRows_<T>(compound, compoundShape, pivot, i, rowNum);
_swapRows<T>(permutation, permutationShape, pivot, i, rowNum); swapRows_<T>(permutation, permutationShape, pivot, i, rowNum);
if (pivot != i) if (pivot != i)
swapCount++; swapCount++;
@ -115,124 +267,582 @@ namespace helpers {
} }
} }
} }
template <typename T>
static __global__ void determinantKernel(T* compound, Nd4jLong* shape, T* result) {
__shared__ Nd4jLong len;
if (threadIdx.x == 0) { template <typename T, typename F>
len = shape::length(shape); static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) {
__shared__ F tempRes;
if (blockIdx.x == 0) {
tempRes = (F)result[0];
} }
__syncthreads();
auto start = blockIdx.x * blockDim.x + threadIdx.x; auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x; auto step = blockDim.x * gridDim.x;
for (auto i = start; i < len; i += step) { for (auto i = start; i < len; i += step) {
Nd4jLong di[] = {i, i}; auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2);
auto pos = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2); math::atomics::nd4j_atomicMul<F>(&tempRes, (F)compound[pos]);
math::atomics::nd4j_atomicMul(result, compound[pos]); }
__syncthreads();
if (blockIdx.x == 0) {
result[0] = (T)tempRes;
} }
} }
template <typename T>
static __global__ void determinantFullKernel(T* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape, Nd4jLong* tadShape, Nd4jLong* tadOffsets) {
template <typename T, typename F>
static __global__ void determinantLogKernel(T* compound, T* result, Nd4jLong len) {
__shared__ F tempRes;
if (blockIdx.x == 0) {
tempRes = (F)result[0];
}
__syncthreads();
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for (auto i = start; i < len; i += step) {
auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2);
math::atomics::nd4j_atomicMul<F>(&tempRes, (F)compound[pos]);
}
__syncthreads();
if (blockIdx.x == 0) {
result[0] = (T)math::nd4j_log<F,F>(math::nd4j_abs(tempRes));
}
}
template <typename T, typename F>
static __global__ void fillMatrix(void* output, Nd4jLong* outShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) {
__shared__ F* matrix;
__shared__ T* inputBuf;
__shared__ Nd4jLong inputLen;
__shared__ Nd4jLong n2;
if (threadIdx.x == 0) {
matrix = reinterpret_cast<F*>(output);
inputBuf = reinterpret_cast<T*>(input);
inputLen = shape::length(inputShape);
n2 = rowLen * rowLen;
}
__syncthreads();
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for (int k = pos + start, j = start; j < n2; k += step, j += step) {
auto xIndex = shape::getIndexOffset(k, inputShape, inputLen);
matrix[j] = (F)inputBuf[xIndex];
}
}
template <typename F>
static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) {
__shared__ F* permutation;
if (threadIdx.x == 0) {
permutation = reinterpret_cast<F*>(output);
}
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for (auto i = start; i < rowNum; i += step) {
int val = source[i] - 1;
Nd4jLong posF[] = {i, val};
auto pos = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), posF, 2);
permutation[pos] = F(1.f);
}
} }
template <typename T> template <typename T>
static NDArray _lup(LaunchContext* context, NDArray* input, NDArray* compound, NDArray* permutation) { static void lup_(LaunchContext* context, NDArray* input, NDArray* compound, NDArray* permutation) {
NDArray determinant = NDArrayFactory::create<T>(1.f);
auto rowNum = input->rows();
auto columnNum = input->columns();
NDArray compoundMatrix = *input; // copy
NDArray permutationMatrix(input, false, input->getContext()); // has same shape as input and contiguous strides
permutationMatrix.setIdentity();
T pivotValue; // = T(0.0);
int pivot; // = -1;
int swapCount = 0;
T* compoundBuf = reinterpret_cast<T*>(compoundMatrix.specialBuffer());
T* permutationBuf = reinterpret_cast<T*>(permutationMatrix.specialBuffer());
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
lupKernel<T><<<256, 256, 1024, *stream>>>(compoundBuf, compoundMatrix.specialShapeInfo(), permutationBuf, permutationMatrix.specialShapeInfo(), rowNum); auto n = input->rows();
determinantKernel<T><<<256, 256, 1024, *stream>>>(compoundBuf, compoundMatrix.specialShapeInfo(), reinterpret_cast<T*>(determinant.specialBuffer())); cusolverDnHandle_t cusolverH = nullptr;
// for (int e = 0; e < rowNum; e++) { cusolverStatus_t status = cusolverDnCreate(&cusolverH);
// // nd4j_printf("Compound matrix diag %i %f.\n", e, (*compoundMatrix)(e, e)); if (CUSOLVER_STATUS_SUCCESS != status) {
// determinant *= compoundMatrix.e<T>(e, e); throw cuda_exception::build("Cannot create cuSolver handle", status);
// } }
if (swapCount % 2) determinant = -determinant; status = cusolverDnSetStream(cusolverH, *stream);
if (compound != nullptr) if (CUSOLVER_STATUS_SUCCESS != status) {
compound->assign(compoundMatrix); throw cuda_exception::build("Cannot set up stream for cuda solver", status);
if (permutation != nullptr) }
permutation->assign(permutationMatrix); int lwork = 0;
return determinant; int *d_info = nullptr;
auto err = cudaMalloc((void **) &d_info, sizeof(int));
if (err) {
throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver info buffer", err);
}
DataType dtype = input->dataType();
switch(dtype) {
case DataType::DOUBLE: {
double *d_work = nullptr;
err = cudaMalloc((void **) &d_work, sizeof(float) * lwork);
if (err) {
throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", err);
}
double *matrix = reinterpret_cast<double*>(input->specialBuffer());
status = cusolverDnDgetrf_bufferSize(
cusolverH,
n,
n,
matrix,
n,
&lwork);
if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status);
}
if (permutation == nullptr)
status = cusolverDnDgetrf(
cusolverH,
n,
n,
matrix,
n,
d_work,
nullptr,
d_info);
else {
NDArray permutVector('c', {n}, nd4j::DataType::INT32, context);
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
status = cusolverDnDgetrf(
cusolverH,
n,
n,
matrix,
n,
d_work,
permutationBuf,
d_info);
fillUpPermutation<double><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
permutation->tickWriteDevice();
}
err = cudaFree(d_work);
if (err) {
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", err);
}
}
break;
case DataType::FLOAT32: {
float *matrix = reinterpret_cast<float*>(input->specialBuffer());
float *d_work = nullptr;
err = cudaMalloc((void **) &d_work, sizeof(float) * lwork);
if (err) {
throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", err);
}
status = cusolverDnSgetrf_bufferSize(
cusolverH,
n,
n,
matrix,
n,
&lwork);
if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status);
}
if (permutation == nullptr)
status = cusolverDnSgetrf(
cusolverH,
n,
n,
matrix,
n,
d_work,
nullptr,
d_info);
else {
NDArray permutVector('c', {n}, nd4j::DataType::INT32, context);
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
status = cusolverDnSgetrf(
cusolverH,
n,
n,
matrix,
n,
d_work,
permutationBuf,
d_info);
fillUpPermutation<float><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
permutation->tickWriteDevice();
}
err = cudaFree(d_work);
if (err) {
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", err);
}
}
}
if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("helpers::lup_: Cannot make LU decomposition", status);
}
err = cudaFree(d_info);
if (err) {
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err);
}
cusolverDnDestroy(cusolverH);
// NDArray::registerSpecialUse({input}, {input});
input->tickWriteDevice();
} }
BUILD_SINGLE_TEMPLATE(template NDArray _lup, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
template <typename T> template <typename T>
static int _determinant(nd4j::LaunchContext* context, NDArray* input, NDArray* output) { static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
Nd4jLong n = input->sizeAt(-1); Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n; Nd4jLong n2 = n * n;
std::vector<int> dims(); std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1}); auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1}); //auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
DataType dtype = input->dataType();
if (dtype != DataType::DOUBLE)
dtype = DataType::FLOAT32;
//auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace()); auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, input->getContext()); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
auto inputBuf = reinterpret_cast<T*>(input->specialBuffer()); NDArray::prepareSpecialUse({output}, {input});
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer());
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);
determinantFullKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, input->specialShapeInfo(), outputBuf, output->specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets()); output->assign(1.f);
// for (int e = 0; e < output->lengthOf(); e++) { for (int e = 0; e < output->lengthOf(); e++) {
// for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row) Nd4jLong pos = e * n2;
// matrix.p(row, input->e<T>(k)); if (matrix.dataType() == input->dataType())
//// output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr)); fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
// } else
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
if (matrix.dataType() == input->dataType())
lup_<T>(context, &matrix, nullptr, nullptr);
else
lup_<float>(context, &matrix, nullptr, nullptr);
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
if (matrix.dataType() == input->dataType())
determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
else
determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
}
NDArray::registerSpecialUse({output}, {input});
return Status::OK(); return Status::OK();
} }
BUILD_SINGLE_TEMPLATE(template int _determinant, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return _determinant, (context, input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
} }
template <typename T> template <typename T>
int log_abs_determinant_(NDArray* input, NDArray* output) { int logAbsDeterminant_(LaunchContext* context, NDArray* input, NDArray* output) {
Nd4jLong n = input->sizeAt(-1);
Nd4jLong n2 = n * n;
std::vector<int> dims();
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
DataType dtype = input->dataType();
if (dtype != DataType::DOUBLE)
dtype = DataType::FLOAT32;
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, input->getContext()); //, block.getWorkspace());
auto det = NDArrayFactory::create<T>(1);
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({output}, {input});
dim3 launchDims(256, 256, 1024);
output->assign(1.f);
for (int e = 0; e < output->lengthOf(); e++) {
Nd4jLong pos = e * n2;
if (matrix.dataType() == input->dataType())
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
else
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
if (matrix.dataType() == input->dataType())
lup_<T>(context, &matrix, nullptr, nullptr);
else
lup_<float>(context, &matrix, nullptr, nullptr);
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
if (matrix.dataType() == input->dataType())
determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
else
determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
}
NDArray::registerSpecialUse({output}, {input});
return Status::OK();
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
BUILD_SINGLE_TEMPLATE(template int log_abs_determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
int log_abs_determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return log_abs_determinant_, (input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
} }
template <typename T> template <typename T>
static int _inverse(NDArray* input, NDArray* output) { static __global__ void fillLowerUpperKernel(void* lowerBuf, Nd4jLong* lowerShape, void* upperBuf, Nd4jLong* upperShape, void* matrixBuf, Nd4jLong* matrixShape, Nd4jLong n) {
__shared__ Nd4jLong* xShapeOf;
__shared__ Nd4jLong* yShapeOf;
__shared__ Nd4jLong* zShapeOf;
__shared__ Nd4jLong* xStrideOf;
__shared__ Nd4jLong* yStrideOf;
__shared__ Nd4jLong* zStrideOf;
__shared__ T* lowerMatrix;
__shared__ T* upperMatrix;
__shared__ T* matrix;
if (threadIdx.x == 0) {
xShapeOf = shape::shapeOf(lowerShape);
yShapeOf = shape::shapeOf(upperShape);
zShapeOf = shape::shapeOf(matrixShape);
xStrideOf = shape::stride(lowerShape);
yStrideOf = shape::stride(upperShape);
zStrideOf = shape::stride(matrixShape);
lowerMatrix = reinterpret_cast<T*>(lowerBuf);
upperMatrix = reinterpret_cast<T*>(upperBuf);
matrix = reinterpret_cast<T*>(matrixBuf);
}
__syncthreads();
for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it
for (int j = threadIdx.x; j < n; j += blockDim.x) {
Nd4jLong posX[] = {j, k};
auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2);
auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2);
auto pos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2);
if (k <= j)
lowerMatrix[xPos] = matrix[pos];//(k, j);
else
upperMatrix[yPos] = matrix[pos]; //k, j);
}
}
}
template <typename T>
static int inverse_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
auto n = input->sizeAt(-1);
auto n2 = n * n;
auto dtype = input->dataType();
if (dtype != DataType::DOUBLE)
dtype = DataType::FLOAT32;
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1});
auto stream = context->getCudaStream();
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
fillMatrix<T, float><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
permutation.assign(0.f);
lup_<float>(context, &matrix, &compound, &permutation);
matrix.tickWriteDevice();
permutation.tickWriteDevice();
permutation.printIndexedBuffer("PERMUTE");
lower.setIdentity(); // set up U to identity matrix
upper.setIdentity();
fillLowerUpperKernel<float><<<1, n2, 128>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n);
lower.tickWriteDevice();
upper.tickWriteDevice();
invertUpperMatrix(&upper, &matrix);
invertLowerMatrix(&lower, &upper);
lower.tickWriteDevice();
upper.tickWriteDevice();
lower.printIndexedBuffer("LOWER");
upper.printIndexedBuffer("UPPER");
nd4j::MmulHelper::mmul(&matrix, &upper, &compound, 1.0, 0.0);
nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0);
// for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
// output->t<T>(k) = matrix.template t<T>(row++);
// }
}
return Status::OK(); return Status::OK();
} }
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
BUILD_SINGLE_SELECTOR(input->dataType(), return _inverse, (input, output), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
} }
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) { bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
return false; return true;
} }
template <typename T> template <typename F>
int cholesky_(NDArray* input, NDArray* output, bool inplace) { __global__ void fillBatchKernel(F** dArrayBatch, F* buf, Nd4jLong* offsets, Nd4jLong batchSize) {
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
for (auto i = start; i < batchSize; i += step) {
dArrayBatch[i] = buf + offsets[i];
}
}
template <typename F>
__global__ void adjustResultsKernel(F* dArray, Nd4jLong* shape, Nd4jLong* offsets, Nd4jLong batchSize, Nd4jLong n) {
//auto i = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ Nd4jLong* shapeOf;
__shared__ Nd4jLong* strideOf;
if (blockIdx.x == 0 && threadIdx.x == 0) {
shapeOf = shape::shapeOf(shape);
strideOf = shape::stride(shape);
}
__syncthreads();
for (auto i = blockIdx.x; i < batchSize; i+= gridDim.x) {
auto current = dArray + offsets[i];
for (auto r = threadIdx.x; r < n; r += blockDim.x) {
for (auto c = r + 1; c < n; c++) {
Nd4jLong posRC[] = {r, c};
auto pos = r * n + c; //shape::getOffset(0, shapeOf, strideOf, posRC, 2);
current[pos] = 0.;
}
}
}
}
template <typename F>
int cholesky__(LaunchContext* context, NDArray* input, NDArray* output, bool inplace) {
if (!inplace)
output->assign(input);
std::unique_ptr<NDArray> tempOutput(output->dup());
cusolverDnHandle_t handle = nullptr;
auto n = input->sizeAt(-1);
auto n2 = n * n;
NDArray::prepareSpecialUse({output}, {input});
auto status = cusolverDnCreate(&handle);
if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status);
}
F** dArrayBatch = nullptr;
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {tempOutput->rankOf() - 2, tempOutput->rankOf() - 1});
const Nd4jLong batchSize = packX.numberOfTads();
int* dInfoArray = nullptr;
auto err = cudaMalloc((void**)&dArrayBatch, sizeof(F*) * batchSize);
if (err) {
throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver batch data buffer", err);
}
err = cudaMalloc ((void**)&dInfoArray, sizeof(int) * batchSize);
if (err) {
throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err);
}
auto stream = context->getCudaStream();
fillBatchKernel<F><<<1, batchSize, 128, *stream>>>(dArrayBatch, reinterpret_cast<F*>(tempOutput->specialBuffer()), packX.specialOffsets(), batchSize);
status = cusolverDnSetStream(handle, *stream);
if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("helpers::cholesky_: Cannot set stream to solver handle", status);
}
const cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
if (input->dataType() == DataType::DOUBLE)
status = cusolverDnDpotrfBatched(
handle,
uplo,
n,
(double**)dArrayBatch,
n,
dInfoArray,
batchSize);
else
status = cusolverDnSpotrfBatched(
handle,
uplo,
n,
(float**)dArrayBatch,
n,
dInfoArray,
batchSize);
if (CUSOLVER_STATUS_SUCCESS != status) {
throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status);
}
adjustResultsKernel<F><<<batchSize, n2, 128, *stream>>>(reinterpret_cast<F*>(tempOutput->specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n);
err = cudaFree(dArrayBatch);
if (err) {
throw cuda_exception::build("helpers::cholesky_: Cannot deallocate memory for solver batch data buffer", err);
}
err = cudaFree(dInfoArray);
if (err) {
throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err);
}
if(!inplace)
output->assign(tempOutput.get());
NDArray::registerSpecialUse({output}, {input});
return Status::OK(); return Status::OK();
} }
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) { // template <typename T>
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES); int cholesky_(LaunchContext* context, NDArray* input, NDArray* output, bool inplace) {
} if (input->dataType() == DataType::DOUBLE)
BUILD_SINGLE_TEMPLATE(template int cholesky_, (NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES); cholesky__<double>(context, input, output, inplace);
BUILD_SINGLE_TEMPLATE(template int _inverse, (NDArray* input, NDArray* output), FLOAT_TYPES); else if (input->dataType() == DataType::FLOAT32)
cholesky__<float>(context, input, output, inplace);
else {
std::unique_ptr<NDArray> tempOutput(NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, input->getContext()));
tempOutput->assign(input);
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
output->assign(tempOutput.get());
}
return Status::OK();
}
int cholesky(nd4j::LaunchContext* context, NDArray* input, NDArray* output, bool inplace) {
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
return cholesky_(context, input, output, inplace);
}
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) { __global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) {
return 119; __shared__ double* output;
__shared__ double* input;
__shared__ int n2;
if (threadIdx.x == 0) {
output = reinterpret_cast<double*>(outputBuf);
input = reinterpret_cast<double*>(inputBuf);
n2 = shape::sizeAt(inputShape, -1) * shape::sizeAt(inputShape, -1);
}
__syncthreads();
for (Nd4jLong i = blockIdx.x; i < batchNum; i += gridDim.x) {
double* current = input + tadOffsets[i];
Nd4jLong* shapeOf = shape::shapeOf(tadShape);
Nd4jLong* strideOf = shape::stride(tadShape);
auto zIndex = shape::getIndexOffset(i, outputShape, batchNum);
for (Nd4jLong e = threadIdx.x; e < n2; e += blockDim.x) {
Nd4jLong diag[] = {e, e};
auto xIndex = shape::getOffset(0, shapeOf, strideOf, diag, 2);
math::atomics::nd4j_atomicAdd(&output[zIndex], math::nd4j_log<double,double>(current[xIndex] * current[xIndex]));
}
}
}
int logdetFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
NDArray::prepareSpecialUse({output}, {input});
auto tempOutput = input->dup('c');
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
auto stream = context->getCudaStream();
cholesky(context, tempOutput, tempOutput, true);
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {tempOutput->rankOf() - 2, tempOutput->rankOf() - 1});
//for (Nd4jLong e = 0; e < output->lengthOf(); e++) {
auto outputBuf = reinterpret_cast<double*>(output->specialBuffer()); // + e * n2;
logDetKernel<<<packX.numberOfTads(), n2, 128, *stream>>>(tempOutput->specialBuffer(), tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo());
//}
NDArray::registerSpecialUse({output}, {input});
delete tempOutput;
return Status::OK();
} }
} }
} }

View File

@ -732,10 +732,81 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i
manager.synchronize(); manager.synchronize();
} }
///////////////////////////////////////////////////////////////////
template<typename X, typename Z>
__global__ void scatterForLossCuda(const void *vx, const Nd4jLong *xShapeInfo,
void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo) {
const auto x = reinterpret_cast<const X*>(vx);
auto y = reinterpret_cast<Z*>(vy);
auto z = reinterpret_cast<Z*>(vz);
void scatterForLoss(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& updates, NDArray& output, const bool calcGrad) { __shared__ Nd4jLong xLen, *sharedMem;
__shared__ int xRank; // xRank = zRank, yRank = xRank + 1
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xLen = shape::length(xShapeInfo);
xRank = shape::rank(xShapeInfo);
}
__syncthreads();
const auto xInd = threadIdx.x + blockIdx.x * blockDim.x;
if(xInd >= xLen)
return;
auto coords = sharedMem + threadIdx.x * (xRank + 1);
shape::index2coords(xRank, xShapeInfo + 1, xInd, xLen, coords);
// y last coordinate
coords[xRank] = x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords, xRank)];
const auto yOffset = shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank + 2, coords, xRank + 1);
if(z == nullptr) { // gradient calculation
y[yOffset] -= 1.f;
}
else {
z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords, xRank)] = y[yOffset];
}
}
///////////////////////////////////////////////////////////////////
template<typename X, typename Z>
static void scatterForLossCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong* xShapeInfo, void *vy, const Nd4jLong* yShapeInfo, void *vz, const Nd4jLong* zShapeInfo) {
scatterForLossCuda<X, Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
}
///////////////////////////////////////////////////////////////////
void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) {
// shapes of indices and output must be the same
// shape of indices should be the same as updates shape with last dimension excluded, for example if updates is {a,b,c} then indices should be {a,b}
PointersManager manager(context, "scatterForLoss");
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = updates.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
if(calcGrad) {
NDArray::prepareSpecialUse({&updates}, {&indices});
BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INTEGER_TYPES, FLOAT_TYPES);
NDArray::registerSpecialUse({&updates}, {&indices});
}
else {
NDArray::prepareSpecialUse({&output}, {&indices, &updates});
BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INTEGER_TYPES, FLOAT_TYPES);
NDArray::registerSpecialUse({&output}, {&indices, &updates});
}
manager.synchronize();
} }

View File

@ -22,6 +22,8 @@
#include<ops/declarable/helpers/sru.h> #include<ops/declarable/helpers/sru.h>
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
#include <PointersManager.h>
#include <MmulHelper.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
@ -103,30 +105,432 @@ void sruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { __global__ static void sruBICuda(const void* vx, const Nd4jLong* xShapeInfo,
const void* vwi, const Nd4jLong* wiShapeInfo,
const void* vb, const Nd4jLong* bShapeInfo,
const void* vc0, const Nd4jLong* c0ShapeInfo,
const void* vmask, const Nd4jLong* maskShapeInfo,
void* vht, const Nd4jLong* htShapeInfo,
void* vct, const Nd4jLong* ctShapeInfo) {
// inputs:
// x [time, bS, 2*K]
// wi [time, bS, 6*K], wi = mmul(x, weights);
// b [4*K]
// c0 [bS, 2*K]
// mask [bS, 2*K], optional
// outputs
// ht [time, bS, 2*K]
// ct [time, bS, 2*K]
const auto x = reinterpret_cast<const T*>(vx);
const auto wi = reinterpret_cast<const T*>(vwi);
const auto b = reinterpret_cast<const T*>(vb);
const auto c0 = reinterpret_cast<const T*>(vc0);
const auto mask = reinterpret_cast<const T*>(vmask);
auto ht = reinterpret_cast<T*>(vht);
auto ct = reinterpret_cast<T*>(vct);
const int rank = 3;
__shared__ int time, K;
__shared__ Nd4jLong len, totalThreads, *sharedMem;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
time = xShapeInfo[1];
K = xShapeInfo[3] / 2;
len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS
totalThreads = gridDim.x * blockDim.x;
} }
////////////////////////////////////////////////////////////////////////// __syncthreads();
template <typename T>
static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask, const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { Nd4jLong* coords = sharedMem + threadIdx.x * rank;
if(tid >= len)
return;
shape::index2coords(rank, xShapeInfo + 2, tid, len, coords + 1); // loop through last two dimensions of x : {bS, 2*K}
const auto maskOffst = mask ? shape::getOffset(0, maskShapeInfo + 1, maskShapeInfo + rank, coords + 1, rank - 1) : 0;
const auto c0Offset = shape::getOffset(0, c0ShapeInfo + 1, c0ShapeInfo + rank, coords + 1, rank - 1);
const auto bFOffset = shape::getOffset(0, bShapeInfo + 1, bShapeInfo + rank - 1, coords + 2, rank - 2);
const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride
const T maskVal = mask ? mask[maskOffst] : static_cast<T>(1);
const T bF = b[bFOffset];
const T bR = b[bROffset];
T c0Val = c0[c0Offset];
const bool flip = coords[2] >= K;
if(flip)
coords[0] = time - 1;
else
coords[0] = 0;
auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
auto htOffset = shape::getOffset(0, htShapeInfo + 1, htShapeInfo + rank + 1, coords, rank);
auto ctOffset = shape::getOffset(0, ctShapeInfo + 1, ctShapeInfo + rank + 1, coords, rank);
coords[2] *= 3;
auto wiOffset0 = shape::getOffset(0, wiShapeInfo + 1, wiShapeInfo + rank + 1, coords, rank);
auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride
auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride
// time loop
for (uint t = 0; t < time; ++t) {
// evaluate sigmoids
T ft = (1.f)/(1.f + nd4j::math::nd4j_exp<T, T>(-(wi[wiOffset1] + bF)));
T rt = (1.f)/(1.f + nd4j::math::nd4j_exp<T, T>(-(wi[wiOffset2] + bR)));
c0Val = (c0Val - wi[wiOffset0]) * ft + wi[wiOffset0];
ct[ctOffset] = c0Val;
T val = nd4j::math::nd4j_tanh<T, T>(c0Val);
T xVal = x[xOffset];
ht[htOffset] = (val * maskVal - xVal) * rt + xVal;
if(flip) {
xOffset -= xShapeInfo[rank + 1]; // first stride, corresponds to time step
htOffset -= htShapeInfo[rank + 1];
ctOffset -= htShapeInfo[rank + 1];
wiOffset0 -= wiShapeInfo[rank + 1];
wiOffset1 -= wiShapeInfo[rank + 1];
wiOffset2 -= wiShapeInfo[rank + 1];
}
else {
xOffset += xShapeInfo[rank + 1]; // first stride, corresponds to time step
htOffset += htShapeInfo[rank + 1];
ctOffset += htShapeInfo[rank + 1];
wiOffset0 += wiShapeInfo[rank + 1];
wiOffset1 += wiShapeInfo[rank + 1];
wiOffset2 += wiShapeInfo[rank + 1];
}
}
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void sruBICudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo,
const void* vwi, const Nd4jLong* wiShapeInfo,
const void* vb, const Nd4jLong* bShapeInfo,
const void* vc0, const Nd4jLong* c0ShapeInfo,
const void* vmask, const Nd4jLong* maskShapeInfo,
void* vht, const Nd4jLong* htShapeInfo,
void* vct, const Nd4jLong* ctShapeInfo) {
sruBICuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vht, htShapeInfo, vct, ctShapeInfo);
}
BUILD_SINGLE_TEMPLATE(template void sruBICudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, const Nd4jLong* maskShapeInfo, void* vht, const Nd4jLong* htShapeInfo, void* vct, const Nd4jLong* ctShapeInfo), FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
// x = x * mask
if(mask)
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
// U = x * w
NDArray wi = mmul(*x, *w); // U [time x bS x 6*K]
PointersManager manager(context, "sru_bi");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / threadsPerBlock; // loop through last two dimensions of x array -> bS, 2*K
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * x->rankOf() + 128;
NDArray::prepareSpecialUse({ht, ct}, {x, &wi, b, c0, mask});
BUILD_SINGLE_SELECTOR(x->dataType(), sruBICudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), wi.getSpecialBuffer(), wi.getSpecialShapeInfo(), b->getSpecialBuffer(), b->getSpecialShapeInfo(), c0->getSpecialBuffer(), c0->getSpecialShapeInfo(), mask ? mask->getSpecialBuffer() : nullptr, mask ? mask->getSpecialShapeInfo() : nullptr, ht->specialBuffer(), ht->specialShapeInfo(), ct->specialBuffer(), ct->specialShapeInfo()), FLOAT_TYPES);
NDArray::registerSpecialUse({ht, ct}, {x, &wi, b, c0, mask});
manager.synchronize();
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
__global__ static void sruBIBPCuda(const void* vx, const Nd4jLong* xShapeInfo,
const void* vwi, const Nd4jLong* wiShapeInfo,
const void* vb, const Nd4jLong* bShapeInfo,
const void* vc0, const Nd4jLong* c0ShapeInfo,
const void* vmask, const Nd4jLong* maskShapeInfo,
const void* vct, const Nd4jLong* ctShapeInfo,
const void* vgradHt, const Nd4jLong* gradHtShapeInfo,
const void* vgradCt, const Nd4jLong* gradCtShapeInfo,
void* vgradI, const Nd4jLong* gradIShapeInfo,
void* vgradWi, const Nd4jLong* gradWiShapeInfo,
void* vgradB, const Nd4jLong* gradBShapeInfo,
void* vgradC0, const Nd4jLong* gradC0ShapeInfo) {
// inputs:
// x [time, bS, 2*K]
// wi [time, bS, 6*K], wi = mmul(x, weights);
// b [4*K]
// c0 [bS, 2*K]
// mask [bS, 2*K], optional
// ct [time, bS, 2*K]
// gradHt [time, bS, 2*K]
// gradCt [bS, 2*K]
// outputs
// gradI [time, bS, 2*K]
// gradWi [time, 2*K, 6*K]
// gradB [bS, 4*K]
// gradC0 [bS, 2*K]
const auto x = reinterpret_cast<const T*>(vx);
const auto wi = reinterpret_cast<const T*>(vwi);
const auto b = reinterpret_cast<const T*>(vb);
const auto c0 = reinterpret_cast<const T*>(vc0);
const auto mask = reinterpret_cast<const T*>(vmask);
const auto ct = reinterpret_cast<const T*>(vct);
const auto gradHt = reinterpret_cast<const T*>(vgradHt);
const auto gradCt = reinterpret_cast<const T*>(vgradCt);
auto gradI = reinterpret_cast<T*>(vgradI);
auto gradWi = reinterpret_cast<T*>(vgradWi);
auto gradB = reinterpret_cast<T*>(vgradB);
auto gradC0 = reinterpret_cast<T*>(vgradC0);
const int rank = 3;
__shared__ int time, K;
__shared__ Nd4jLong len, totalThreads, *sharedMem;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
time = xShapeInfo[1];
K = xShapeInfo[3] / 2;
len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS
totalThreads = gridDim.x * blockDim.x;
} }
__syncthreads();
void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
BUILD_SINGLE_SELECTOR(x->dataType(), sruBI_, (x, w, b, c0, mask, ht, ct), FLOAT_TYPES); Nd4jLong* coords = sharedMem + threadIdx.x * rank;
if(tid >= len)
return;
shape::index2coords(rank, xShapeInfo + 2, tid, len, coords + 1); // loop through last two dimensions of x : {bS, 2*K}
const auto maskOffst = mask ? shape::getOffset(0, maskShapeInfo + 1, maskShapeInfo + rank, coords + 1, rank - 1) : 0;
const auto c0Offset = shape::getOffset(0, c0ShapeInfo + 1, c0ShapeInfo + rank, coords + 1, rank - 1);
const auto gradCtOffset = shape::getOffset(0, gradCtShapeInfo + 1, gradCtShapeInfo + rank, coords + 1, rank - 1);
const auto gradC0Offset = shape::getOffset(0, gradC0ShapeInfo + 1, gradC0ShapeInfo + rank, coords + 1, rank - 1);
const auto bFOffset = shape::getOffset(0, bShapeInfo + 1, bShapeInfo + rank - 1, coords + 2, rank - 2);
const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride
// const auto gradBFOffset = shape::getOffset(0, gradBShapeInfo + 1, gradBShapeInfo + rank, coords + 1, rank - 1);
const auto gradBFOffset = coords[1] * gradBShapeInfo[3] / 2 + coords[2] * gradBShapeInfo[4];
const auto gradBROffset = gradBFOffset + gradBShapeInfo[3];
const bool flip = coords[2] >= K;
if(flip)
coords[0] = 0;
else
coords[0] = time - 1;
auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
auto ctOffset = shape::getOffset(0, ctShapeInfo + 1, ctShapeInfo + rank + 1, coords, rank);
auto gradIOffset = shape::getOffset(0, gradIShapeInfo + 1, gradIShapeInfo + rank + 1, coords, rank);
auto gradHtOffset = shape::getOffset(0, gradHtShapeInfo + 1, gradHtShapeInfo + rank + 1, coords, rank);
coords[2] *= 3;
auto gradWiOffset0 = shape::getOffset(0, gradWiShapeInfo + 1, gradWiShapeInfo + rank + 1, coords, rank);
auto gradWiOffset1 = gradWiOffset0 + gradWiShapeInfo[rank + 3]; // add last stride
auto gradWiOffset2 = gradWiOffset1 + gradWiShapeInfo[rank + 3]; // add last stride
auto wiOffset0 = shape::getOffset(0, wiShapeInfo + 1, wiShapeInfo + rank + 1, coords, rank);
auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride
auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride
const T xVal = x[xOffset];
const T maskVal = mask ? mask[maskOffst] : static_cast<T>(1);
const T c0Val = c0[c0Offset];
const T bF = b[bFOffset];
const T bR = b[bROffset];
T gradCtVal = gradCt[gradCtOffset];
T gbF = 0.f;
T gbR = 0.f;
// time loop
for (uint t = 0; t < time; ++t) {
// evaluate sigmoids
T ft = (1.f)/(1.f + nd4j::math::nd4j_exp<T, T>(-(wi[wiOffset1] + bF)));
T rt = (1.f)/(1.f + nd4j::math::nd4j_exp<T, T>(-(wi[wiOffset2] + bR)));
T val = nd4j::math::nd4j_tanh<T,T>(ct[ctOffset]);
T prevVal;
if(t < time-1)
prevVal = ct[ctOffset += flip ? ctShapeInfo[rank + 1] : -ctShapeInfo[rank + 1]];
else
prevVal = c0Val;
// grad wrt input
gradI[gradIOffset] = gradHt[gradHtOffset] - gradHt[gradHtOffset] * rt ;
// grad wrt rt, wiR and bR
T grt = gradHt[gradHtOffset] * (val * maskVal - x[xOffset]) * (rt - rt * rt);
gradWi[gradWiOffset2] = grt;
gbR += grt;
// grad wrt state
T gradC0Val = gradHt[gradHtOffset] * maskVal * (rt - rt * val * val) + gradCtVal;
// grad wrt wi0
gradWi[gradWiOffset0] = gradC0Val - gradC0Val * ft;
// grad wrt ft, wi1, and bF
T gft = gradC0Val * (prevVal - wi[wiOffset0]) * (ft - ft * ft);
gradWi[gradWiOffset1] = gft;
gbF += gft;
// grad wrt c_previous
gradCtVal = gradC0Val * ft;
if(flip) {
xOffset += xShapeInfo[rank + 1]; // first stride, corresponds to time step
gradHtOffset += gradHtShapeInfo[rank + 1];
gradIOffset += gradIShapeInfo[rank + 1];
wiOffset0 += wiShapeInfo[rank + 1];
wiOffset1 += wiShapeInfo[rank + 1];
wiOffset2 += wiShapeInfo[rank + 1];
gradWiOffset0 += gradWiShapeInfo[rank + 1];
gradWiOffset1 += gradWiShapeInfo[rank + 1];
gradWiOffset2 += gradWiShapeInfo[rank + 1];
}
else {
xOffset -= xShapeInfo[rank + 1]; // first stride, corresponds to time step
gradHtOffset -= gradHtShapeInfo[rank + 1];
gradIOffset -= gradIShapeInfo[rank + 1];
wiOffset0 -= wiShapeInfo[rank + 1];
wiOffset1 -= wiShapeInfo[rank + 1];
wiOffset2 -= wiShapeInfo[rank + 1];
gradWiOffset0 -= gradWiShapeInfo[rank + 1];
gradWiOffset1 -= gradWiShapeInfo[rank + 1];
gradWiOffset2 -= gradWiShapeInfo[rank + 1];
}
} }
void sruBIBP(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { gradB[gradBFOffset] = gbF;
BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBP_, (x, w, b, c0, ct, inGradC0, inGradH, mask, gradI, gradW, gradB, gradC0), FLOAT_TYPES); gradB[gradBROffset] = gbR;
} gradC0[gradC0Offset] = gradCtVal;
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void sruBIBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo,
const void* vwi, const Nd4jLong* wiShapeInfo,
const void* vb, const Nd4jLong* bShapeInfo,
const void* vc0, const Nd4jLong* c0ShapeInfo,
const void* vmask, const Nd4jLong* maskShapeInfo,
const void* vct, const Nd4jLong* ctShapeInfo,
const void* vgradHt, const Nd4jLong* gradHtShapeInfo,
const void* vgradCt, const Nd4jLong* gradCtShapeInfo,
void* vgradI, const Nd4jLong* gradIShapeInfo,
void* vgradWi, const Nd4jLong* gradWiShapeInfo,
void* vgradB, const Nd4jLong* gradBShapeInfo,
void* vgradC0, const Nd4jLong* gradC0ShapeInfo) {
sruBIBPCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vct, ctShapeInfo, vgradHt, gradHtShapeInfo, vgradCt, gradCtShapeInfo, vgradI, gradIShapeInfo, vgradWi, gradWiShapeInfo, vgradB, gradBShapeInfo, vgradC0, gradC0ShapeInfo);
}
BUILD_SINGLE_TEMPLATE(template void sruBIBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, const Nd4jLong* maskShapeInfo, const void* vct, const Nd4jLong* ctShapeInfo, const void* vgradHt, const Nd4jLong* gradHtShapeInfo, const void* vgradCt, const Nd4jLong* gradCtShapeInfo, void* vgradI, const Nd4jLong* gradIShapeInfo, void* vgradWi, const Nd4jLong* gradWiShapeInfo, void* vgradB, const Nd4jLong* gradBShapeInfo, void* vgradC0, const Nd4jLong* gradC0ShapeInfo), FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
void sruBIBP(nd4j::LaunchContext* context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct,
const NDArray* gradCt, const NDArray* gradHt, const NDArray* mask,
NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
// x = x * mask
if(mask)
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
// U = x * w
NDArray wi = mmul(*x, *w); // U [time x bS x 6*K]
const int time = x->sizeAt(0);
const int bS = x->sizeAt(1);
const int K = x->sizeAt(2) / 2;
NDArray gradBias(x->ordering(), {bS, 4*K}, x->dataType(), context);
NDArray gradWi (x->ordering(), {time, bS, 6*K}, x->dataType(), context);
PointersManager manager(context, "sru_bi_bp");
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / threadsPerBlock; // loop through last two dimensions of x array -> bS, 2*K
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * x->rankOf() + 128;
NDArray::prepareSpecialUse({gradI, &gradWi, &gradBias, gradC0}, {x, &wi, b, c0, ct, gradCt, gradHt, mask});
BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), wi.getSpecialBuffer(), wi.getSpecialShapeInfo(), b->getSpecialBuffer(), b->getSpecialShapeInfo(), c0->getSpecialBuffer(), c0->getSpecialShapeInfo(), mask ? mask->getSpecialBuffer() : nullptr, mask ? mask->getSpecialShapeInfo() : nullptr, ct->getSpecialBuffer(), ct->getSpecialShapeInfo(), gradHt->getSpecialBuffer(), gradHt->getSpecialShapeInfo(), gradCt->getSpecialBuffer(), gradCt->getSpecialShapeInfo(), gradI->specialBuffer(), gradI->specialShapeInfo(), gradWi.specialBuffer(), gradWi.specialShapeInfo(), gradBias.specialBuffer(), gradBias.specialShapeInfo(), gradC0->specialBuffer(), gradC0->specialShapeInfo()), FLOAT_TYPES);
NDArray::registerSpecialUse({gradI, &gradWi, &gradBias, gradC0}, {x, &wi, b, c0, ct, gradCt, gradHt, mask});
manager.synchronize();
// gradB
gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}); // [4*K]
// gradW
x->permutei({0, 2, 1}); // [time, bS, 2*K] -> [time, 2*K, bS]
MmulHelper::mmul(x, &gradWi, gradW, 1., 0.); // [time, 2*K, bS] x [time, bS , 6*K] = [time, 2*K, 6*K]
}
BUILD_SINGLE_TEMPLATE(template void sruBI_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0), FLOAT_TYPES);
} }
} }

View File

@ -0,0 +1,639 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <svd.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cusolverDn.h>
#include <exceptions/cuda_exception.h>
#include <PointersManager.h>
#include <ShapeUtils.h>
namespace nd4j {
namespace ops {
namespace helpers {
// FIXME -> we should optimize these helpers for the case when input matrices have c order (perform transpositions appropriately)
template <typename T>
__global__ static void inverseColumnSignCuda(void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo) {
T* u = reinterpret_cast<T*>(vu);
T* v = reinterpret_cast<T*>(vv);
__shared__ int rank, uLastButOneColumn, vLastButOneColumn; // uRank = vRank
__shared__ Nd4jLong uLen, vLen;
__shared__ Nd4jLong *sharedMem;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
rank = shape::rank(uShapeInfo);
uLen = shape::length(uShapeInfo);
vLen = shape::length(vShapeInfo);
uLastButOneColumn = uShapeInfo[rank] - 2;
vLastButOneColumn = vShapeInfo[rank - 1] - 2;
}
__syncthreads();
const auto ind = threadIdx.x + blockIdx.x * blockDim.x;
auto coords = sharedMem + threadIdx.x * rank;
// u
for (Nd4jLong i = ind; i < uLen; i += gridDim.x * blockDim.x) {
shape::index2coords(rank, uShapeInfo + 1, i, uLen, coords);
if(coords[rank - 1] == 0 || coords[rank - 1] == uLastButOneColumn) // do not change sign in first and last but one columns
continue;
const auto uOffset = shape::getOffset(0, uShapeInfo + 1, uShapeInfo + rank + 1, coords, rank);
u[uOffset] = -u[uOffset];
}
// v
for (Nd4jLong i = ind; i < vLen; i += gridDim.x * blockDim.x) {
shape::index2coords(rank, vShapeInfo + 1, i, vLen, coords);
if(coords[rank - 2] == 0 || coords[rank - 2] == vLastButOneColumn) // do not change sign in first and last but one columns
continue;
const auto vOffset = shape::getOffset(0, vShapeInfo + 1, vShapeInfo + rank + 1, coords, rank);
v[vOffset] = -v[vOffset];
}
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void inverseColumnSignCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
void* vu, const Nd4jLong* uShapeInfo,
void* vv, const Nd4jLong* vShapeInfo) {
inverseColumnSignCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vu, uShapeInfo, vv, vShapeInfo);
}
BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo), FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& VT, const bool fullUV, const bool calcUV) {
// since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain on input matrix A: A_rows >= A_columns && A_order = 'f'
// we make this function to have deal with 2 valid cases only:
// 1) A_rows >= A_columns and A_corder = 'f'
// 2) A_rows <= A_columns and A_corder = 'c' - int this case perform transposition to get f order
// if 1) or 2) are not met then throw exception
// A [m, n]
// S [n]
// U [m, m] or [m, n] if fullUV = false and m > n
// VT [n, n] or [m, n] if fullUV = false and m < n
if(A.rankOf() != 2)
throw std::runtime_error("svdQR: rank of A array is not equal 2 !");
auto m = A.sizeAt(0);
auto n = A.sizeAt(1);
const int minDim = m < n ? m : n;
const char orderA = A.ordering();
if(m < n)
throw std::runtime_error("svdQR: due to cuda api input constrains given shape of A array are not valid !");
if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
throw std::runtime_error("svdQR: wrong shape of S array !");
if(calcUV) {
if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
throw std::runtime_error("svdQR: wrong shape of U array !");
else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
throw std::runtime_error("svdQR: wrong shape of U array !");
if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&VT))
throw std::runtime_error("svdQR: wrong shape of VT array !");
else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(&VT))
throw std::runtime_error("svdQR: wrong shape of VT array !");
}
NDArray* pA = const_cast<NDArray*>(&A);
NDArray* pS = &S;
NDArray* pU = &U;
NDArray* pVT = &VT;
std::vector<NDArray*> toDelete;
if(pA->ews() != 1 || pA->ordering() == 'c') {
pA = A.dup('f');
toDelete.push_back(pA);
}
if(S.ews() != 1) {
pS = S.dup('f');
toDelete.push_back(pS);
}
if(calcUV) {
if(pU->ews() != 1 || pU->ordering() == 'c') {
pU = U.dup('f');
toDelete.push_back(pU);
}
if(pVT->ews() != 1 || pVT->ordering() == 'c') {
pVT = VT.dup('f');
toDelete.push_back(pVT);
}
}
// create cusolverDn handle
cusolverDnHandle_t handle = nullptr;
cusolverStatus_t status = cusolverDnCreate(&handle);
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdQR: cuda failed !", status);
// stream
status = cusolverDnSetStream(handle, *context->getCudaStream());
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdQR: cuda failed !", status);
// query working space of SVD
int lwork = 0;
if(A.dataType() == DataType::DOUBLE)
status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork);
else if(A.dataType() == DataType::FLOAT32)
status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
else
throw std::invalid_argument("svdQR: given data type is unsupported !");
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdQR: cuda failed !", status);
// allocate memory for dWork
void* dWork = nullptr;
cudaError_t status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
if(status2 != cudaSuccess)
throw cuda_exception::build("svdQR: cuda failed !", status2);
signed char jobu, jobvt;
if(calcUV) {
if(fullUV)
jobu = jobvt = 'A';
else
jobu = jobvt = 'S';
}
else {
jobu = jobvt = 'N';
}
int *devInfo = nullptr;
void* rWork = nullptr;
int lda(m), ldu, ldvt;
if(calcUV) {
ldu = pU->sizeAt(0);
ldvt = pVT->sizeAt(0);
}
PointersManager manager(context, "svdQR");
NDArray::prepareSpecialUse({pS, pU, pVT}, {pA});
// choose appropriate cuda gemm api depending on data types
if(A.dataType() == DataType::DOUBLE) {
status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
}
else if(A.dataType() == DataType::FLOAT32) {
status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
}
else
throw std::invalid_argument("svdQR: given data type is unsupported !");
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdQR: cuda failed !", status);
manager.synchronize();
NDArray::registerSpecialUse({pS, pU, pVT}, {pA});
S.assign(pS);
if(calcUV) {
U.assign(pU);
VT.assign(pVT);
}
for (int i = toDelete.size() - 1; i >= 0; --i)
delete toDelete[i];
if (devInfo)
cudaFree(devInfo);
if (dWork )
cudaFree(dWork);
if (rWork)
cudaFree(rWork);
if(handle)
cusolverDnDestroy(handle);
// cudaDeviceReset();
}
//////////////////////////////////////////////////////////////////////////
static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
// A [m, n]
// S [n]
// U [m, m] or [m, n] if fullUV = false and m > n
// V [n, n] or [n, m] if fullUV = false and m < n
if(A.rankOf() != 2)
throw std::runtime_error("svdJcb: rank of A array is not equal 2 !");
auto m = A.sizeAt(0);
auto n = A.sizeAt(1);
const int minDim = m < n ? m : n;
if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
throw std::runtime_error("svdJcb: wrong shape of S array !");
if(calcUV) {
if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
throw std::runtime_error("svdJcb: wrong shape of U array !");
else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
throw std::runtime_error("svdJcb: wrong shape of U array !");
if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&V))
throw std::runtime_error("svdJcb: wrong shape of V array !");
else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(&V))
throw std::runtime_error("svdJcb: wrong shape of V array !");
}
NDArray* pA = const_cast<NDArray*>(&A);
NDArray* pS = &S;
NDArray* pU = &U;
NDArray* pV = &V;
std::vector<NDArray*> toDelete;
if(pA->ews() != 1 || pA->ordering() == 'c') {
pA = A.dup('f');
toDelete.push_back(pA);
}
if(S.ews() != 1) {
pS = S.dup('f');
toDelete.push_back(pS);
}
if(calcUV) {
if(pU->ews() != 1 || pU->ordering() == 'c') {
pU = U.dup('f');
toDelete.push_back(pU);
}
if(pV->ews() != 1 || pV->ordering() == 'c') {
pV = V.dup('f');
toDelete.push_back(pV);
}
}
// create cusolverDn handle
cusolverDnHandle_t handle = nullptr;
cusolverStatus_t status = cusolverDnCreate(&handle);
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status);
// stream
status = cusolverDnSetStream(handle, *context->getCudaStream());
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status);
// set parameters
gesvdjInfo_t gesvdjParams = nullptr;
status = cusolverDnCreateGesvdjInfo(&gesvdjParams);
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status);
status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status);
status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status);
int *devInfo = nullptr;
const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
const int econ = !fullUV;
int lda(m), ldu(m), ldv(m);
if(calcUV) {
ldu = pU->sizeAt(0);
ldv = pV->sizeAt(0);
}
// query working space of SVD
int lwork = 0;
if(A.dataType() == DataType::DOUBLE)
status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
else if(A.dataType() == DataType::FLOAT32)
status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
else
throw std::invalid_argument("svdJcb: given data type is unsupported !");
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status);
// allocate memory dWork
void* dWork = nullptr;
auto status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
if(status2 != cudaSuccess)
throw cuda_exception::build("svdJcb: cuda failed !", status2);
PointersManager manager(context, "svdJcb");
NDArray::prepareSpecialUse({pS, pU, pV}, {pA});
// choose appropriate cuda gemm api depending on data types
if(A.dataType() == DataType::DOUBLE) {
status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
}
else if(A.dataType() == DataType::FLOAT32) {
status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
}
else
throw std::invalid_argument("svdJcb: given data type is unsupported !");
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdJcb: cuda failed !", status);
manager.synchronize();
NDArray::registerSpecialUse({pS, pU, pV}, {pA});
S.assign(pS);
if(calcUV) {
U.assign(pU);
V.assign(pV);
}
for (int i = toDelete.size() - 1; i >= 0; --i)
delete toDelete[i];
if (devInfo)
cudaFree(devInfo);
if (dWork )
cudaFree(dWork);
if(handle)
cusolverDnDestroy(handle);
if(gesvdjParams)
cusolverDnDestroyGesvdjInfo(gesvdjParams);
// cudaDeviceReset();
}
//////////////////////////////////////////////////////////////////////////
static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
// A [..., m, n]
// S [..., n]
// U [..., m, m] or [..., m, n] if fullUV = false and m > n
// V [..., n, n] or [..., n, m] if fullUV = false and m < n
auto m = A.sizeAt(-2);
auto n = A.sizeAt(-1);
const int minDim = m < n ? m : n;
const Nd4jLong bS = A.lengthOf() / (m * n);
if(m > 32 || n > 32)
throw std::runtime_error("svdBatched: numbers of rows and columns should be <= 32 !");
if(minDim != S.sizeAt(-1))
throw std::runtime_error("svdBatched: wrong shape of S array !");
if(calcUV) {
if(U.sizeAt(-2) != m)
throw std::runtime_error("svdBatched: wrong shape of U array !");
if(U.sizeAt(-1) != (fullUV ? m : minDim))
throw std::runtime_error("svdBatched: wrong shape of U array !");
if(U.lengthOf() / (U.sizeAt(-2) * U.sizeAt(-1)) != bS)
throw std::runtime_error("svdBatched: wrong shape of U array !");
if(V.sizeAt(-2) != n)
throw std::runtime_error("svdBatched: wrong shape of V array !");
if(V.sizeAt(-1) != (fullUV ? n : minDim))
throw std::runtime_error("svdBatched: wrong shape of V array !");
if(V.lengthOf() / (V.sizeAt(-2) * V.sizeAt(-1)) != bS)
throw std::runtime_error("svdBatched: wrong shape of V array !");
}
NDArray* pA = const_cast<NDArray*>(&A);
NDArray* pS = &S;
NDArray* pU = &U;
NDArray* pV = &V;
std::vector<NDArray*> toDelete;
if(pA->ews() != 1 || pA->ordering() == 'c') {
pA = A.dup('f');
toDelete.push_back(pA);
}
if(S.ews() != 1) {
pS = S.dup('f');
toDelete.push_back(pS);
}
if(calcUV) {
if(pU->ews() != 1 || pU->ordering() == 'c') {
pU = U.dup('f');
toDelete.push_back(pU);
}
if(pV->ews() != 1 || pV->ordering() == 'c') {
pV = V.dup('f');
toDelete.push_back(pV);
}
}
// create cusolverDn handle
cusolverDnHandle_t handle = nullptr;
cusolverStatus_t status = cusolverDnCreate(&handle);
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdBatched: cuda failed !", status);
// stream
status = cusolverDnSetStream(handle, *context->getCudaStream());
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdBatched: cuda failed !", status);
// set parameters
gesvdjInfo_t gesvdjParams = nullptr;
status = cusolverDnCreateGesvdjInfo(&gesvdjParams);
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdBatched: cuda failed !", status);
status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdBatched: cuda failed !", status);
status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdBatched: cuda failed !", status);
// devInfo
int *devInfo = nullptr;
auto status2 = cudaMalloc((void**)&devInfo, sizeof(int) * bS);
if(status2 != cudaSuccess)
throw cuda_exception::build("svdBatched: cuda failed !", status2);
status2 = cudaDeviceSynchronize();
if(status2 != cudaSuccess)
throw cuda_exception::build("svdJcb: cuda failed !", status2);
const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
int lda(m), ldu, ldv;
if(calcUV) {
ldu = pU->sizeAt(-2);
ldv = pV->sizeAt(-2);
}
// Ak (i,j) = A[i + 5*j + 25*k]
// query working space of SVD
int lwork = 0;
if(A.dataType() == DataType::DOUBLE)
status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
else if(A.dataType() == DataType::FLOAT32)
status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
else
throw std::invalid_argument("svdBatched: given data type is unsupported !");
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdBatched: cuda failed !", status);
// allocate memory dWork
void* dWork = nullptr;
status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
if(status2 != cudaSuccess)
throw cuda_exception::build("svdBatched: cuda failed !", status2);
status2 = cudaDeviceSynchronize();
if(status2 != cudaSuccess)
throw cuda_exception::build("svdBatched: cuda failed !", status2);
PointersManager manager(context, "svdBatched");
NDArray::prepareSpecialUse({pS, pU, pV}, {pA});
// choose appropriate cuda gemm api depending on data types
if(A.dataType() == DataType::DOUBLE) {
status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams, bS);
}
else if(A.dataType() == DataType::FLOAT32) {
status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams, bS);
}
else
throw std::invalid_argument("svdBatched: given data type is unsupported !");
if(status != CUSOLVER_STATUS_SUCCESS)
throw cuda_exception::build("svdBatched: cuda failed !", status);
manager.synchronize();
NDArray::registerSpecialUse({pS, pU, pV}, {pA});
S.assign(pS);
if(calcUV) {
U.assign(pU);
V.assign(pV);
}
for (int i = toDelete.size() - 1; i >= 0; --i)
delete toDelete[i];
if (devInfo)
cudaFree(devInfo);
if (dWork )
cudaFree(dWork);
if(handle)
cusolverDnDestroy(handle);
if(gesvdjParams)
cusolverDnDestroyGesvdjInfo(gesvdjParams);
// cudaDeviceReset();
}
////////////////////////////////////////////////////////////////////
void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum) {
NDArray* S = outArrs[0];
NDArray* U = outArrs[1];
// NDArray VT = outArrs[2]->transpose();
NDArray* V = outArrs[2];
if(x->rankOf() == 2) {
// svdQR(context, *x, *S, *U, VT, fullUV, calcUV);
svdJcb(context, *x, *S, *U, *V, fullUV, calcUV);
}
else {
// svdBatched(context, *x, *S, *U, *V, fullUV, calcUV);
ResultSet *tadsU(nullptr), *tadsV(nullptr);
auto tadsX = x->allTensorsAlongDimension({x->rankOf() - 2, x->rankOf() - 1});
auto tadsS = S->allTensorsAlongDimension({S->rankOf() - 1});
if(calcUV) {
tadsU = U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1});
tadsV = V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1});
}
for (int i = 0; i < tadsX->size(); ++i)
svdJcb(context, *tadsX->at(i), *tadsS->at(i), calcUV ? *tadsU->at(i) : *S, calcUV ? *tadsV->at(i) : *S, fullUV, calcUV);
delete tadsX;
delete tadsS;
if(calcUV) {
delete tadsU;
delete tadsV;
}
}
}
}
}
}

View File

@ -323,40 +323,148 @@ void trace(nd4j::LaunchContext* context, const NDArray& input, NDArray& output)
manager.synchronize(); manager.synchronize();
} }
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void triuBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag) {
// x and z have same shapes
const auto x = reinterpret_cast<const T*>(vx); // gradO
auto z = reinterpret_cast<T*>(vz); // gradI
__shared__ int rank, areSameOffsets; // xRank = zRank
__shared__ Nd4jLong len, totalThreads, *sharedMem; // xLen = zLen
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
rank = shape::rank(xShapeInfo);
len = shape::length(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void triuBP_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) {
} }
void triuBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) { __syncthreads();
BUILD_SINGLE_SELECTOR(gradO.dataType(), triuBP_, (context, input, gradO, gradI, diagonal), LIBND4J_TYPES);
auto coords = sharedMem + threadIdx.x * rank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < len; i += totalThreads) {
shape::index2coords(rank, zShapeInfo + 1, i, len, coords);
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
if((coords[rank - 2] + diag > coords[rank - 1])) // row + diag > col
z[zOffset] = 0;
else
z[zOffset] = x[areSameOffsets ? zOffset : shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)];
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
static void triuBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag) {
triuBPCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, diag);
}
BUILD_SINGLE_TEMPLATE(template void triuBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag), LIBND4J_TYPES);
///////////////////////////////////////////////////////////////////
void triuBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) {
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * gradO.rankOf() + 128;
PointersManager manager(context, "triuBP");
NDArray::prepareSpecialUse({&gradI}, {&gradO});
BUILD_SINGLE_SELECTOR(gradI.dataType(), triuBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), diagonal), LIBND4J_TYPES);
NDArray::registerSpecialUse({&gradI}, {&gradO});
manager.synchronize();
}
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void tileBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem) {
// x and z have same shapes
const auto x = reinterpret_cast<const T*>(vx); // gradO
auto z = reinterpret_cast<T*>(vz); // gradI
__shared__ int xRank, zRank; // xRank >= zRank
__shared__ Nd4jLong numOfXOffsets, zLen, totalThreads, *sharedMem; // xLen >= zLen
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(zShapeInfo);
zLen = shape::length(zShapeInfo);
numOfXOffsets = shape::length(xShapeInfo) / zLen;
totalThreads = gridDim.x * blockDim.x;
} }
BUILD_SINGLE_TEMPLATE(template void triuBP_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal), LIBND4J_TYPES); __syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto memBuff = sharedMem + threadIdx.x * 2 * xRank;
auto xOffsets = globMem + tid * numOfXOffsets;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
const auto zOffset = shape::getIndexOffset(i, zShapeInfo, zLen);
shape::outerArrayOffsets(xOffsets, i, xShapeInfo, zShapeInfo, memBuff);
z[zOffset] = x[xOffsets[0]]; // first offset
for (Nd4jLong j = 1; j < numOfXOffsets; ++j) // rest offsets
z[zOffset] += x[xOffsets[j]];
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
static void tileBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem) {
tileBPCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, globMem);
}
BUILD_SINGLE_TEMPLATE(template void tileBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem), FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps) {
NDArray memBuff('c', gradO.getShapeAsVector(), nd4j::DataType::INT64, context); // empty auxiliary array for storing device memory which will be used in kernel calculations
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * 2 * gradO.rankOf() + 128;
PointersManager manager(context, "tileBP");
NDArray::prepareSpecialUse({&gradI}, {&gradO, &memBuff});
BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), reinterpret_cast<Nd4jLong*>(memBuff.specialBuffer())), FLOAT_TYPES);
NDArray::registerSpecialUse({&gradI}, {&gradO, &memBuff});
manager.synchronize();
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename T> template <typename T>
@ -1036,18 +1144,6 @@ void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs,
scatterSimpleKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), xLength, packX.numberOfTads(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), iLength, updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), uLength); scatterSimpleKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), xLength, packX.numberOfTads(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), iLength, updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), uLength);
} }
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void tileBP_(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps) {
}
void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps) {
BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBP_, (context, gradO, gradI, reps), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void tileBP_, (nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps), FLOAT_TYPES);
void scatterSimple(nd4j::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions) { void scatterSimple(nd4j::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions) {
auto xType = input.dataType(); auto xType = input.dataType();

View File

@ -20,63 +20,68 @@
#include <ops/declarable/helpers/helpers.h> #include <ops/declarable/helpers/helpers.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void dilation2d(nd4j::LaunchContext * context, NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left); //////////////////////////////////////////////////////////////////////
void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW);
FORCEINLINE Nd4jStatus outputSize(nd4j::LaunchContext * context, int input_size, int filter_size, int dilation_rate, int stride, bool isSameMode, int *output_size, int *padding_before, int *padding_after) { //////////////////////////////////////////////////////////////////////
if (stride <= 0) FORCEINLINE Nd4jStatus outputSize(nd4j::LaunchContext * context, const int inSize, const int k, const int d, const int s, bool isSameMode, int *outSize, int *padding_before, int *padding_after) {
return Status::THROW("Dilation2D: Stride must be > 0"); if (s <= 0)
return Status::THROW("Dilation2D: Stride must be > 0");
if (dilation_rate < 1) if (d < 1)
return Status::THROW("Dilation2D: Dilation rate must be >= 1"); return Status::THROW("Dilation2D: Dilation rate must be >= 1");
int effective_filter_size = (filter_size - 1) * dilation_rate + 1; int kEff = (k - 1) * d + 1;
if (isSameMode) { if (isSameMode) {
*output_size = (input_size + stride - 1) / stride; *outSize = (inSize + s - 1) / s;
const int padding_needed = nd4j::math::nd4j_max<int>(0, (*output_size - 1) * stride + effective_filter_size -input_size); const int padding_needed = nd4j::math::nd4j_max<int>(0, (*outSize - 1) * s + kEff -inSize);
*padding_before = padding_needed / 2; *padding_before = padding_needed / 2;
*padding_after = padding_needed - *padding_before; *padding_after = padding_needed - *padding_before;
} else { } else {
*output_size = (input_size - effective_filter_size + stride) / stride; *outSize = (inSize - kEff + s) / s;
*padding_before = *padding_after = 0; *padding_before = *padding_after = 0;
}
if (*output_size < 0)
return Status::THROW("Dilation2D: output_size has negative value");
return Status::OK();
} }
if (*outSize < 0)
return Status::THROW("Dilation2D: outSize has negative value");
FORCEINLINE Nd4jStatus dilation_hw(nd4j::LaunchContext * context, Nd4jLong *in, Nd4jLong *wh, std::vector<int> &strides, std::vector<int> &rates, bool isSameMode, int *stride_rows, int *stride_cols, int *rate_rows, int *rate_cols, int *pad_top, int *pad_left, int *out_rows, int *out_cols) { return Status::OK();
const int input_rows = shape::sizeAt(in, 1); }
const int input_cols = shape::sizeAt(in, 2);
const int depth = shape::sizeAt(in, 3);
*stride_rows = strides[1]; //////////////////////////////////////////////////////////////////////
*stride_cols = strides[2]; FORCEINLINE Nd4jStatus dilation_hw(nd4j::LaunchContext * context, Nd4jLong *in, Nd4jLong *wh, std::vector<int> &strides, std::vector<int> &rates, bool isSameMode, int *sH, int *sW, int *pH, int *pW, int *dH, int *dW, int *oH, int *oW) {
*rate_rows = rates[1]; const int iH = shape::sizeAt(in, 1);
*rate_cols = rates[2]; const int iW = shape::sizeAt(in, 2);
const int iC = shape::sizeAt(in, 3);
const int filter_rows = shape::sizeAt(wh, 0); *sH = strides[1];
const int filter_cols = shape::sizeAt(wh, 1); *sW = strides[2];
*dH = rates[1];
*dW = rates[2];
const int filter_rows_eff = filter_rows + (filter_rows - 1) * (*rate_rows - 1); const int kH = shape::sizeAt(wh, 0);
const int filter_cols_eff = filter_cols + (filter_cols - 1) * (*rate_cols - 1); const int kW = shape::sizeAt(wh, 1);
const int kHeff = kH + (kH - 1) * (*dH - 1);
const int kWeff = kW + (kW - 1) * (*dW - 1);
int padding_after_unusedA, padding_after_unusedB;
if (outputSize(context, iH, kHeff, 1, *sH, isSameMode, oH, pH, &padding_after_unusedA) != Status::OK())
return Status::THROW("Dilation2D: bad height");
if (outputSize(context, iW, kWeff, 1, *sW, isSameMode, oW, pW, &padding_after_unusedA) != Status::OK())
return Status::THROW("Dilation2D: bad width");
return Status::OK();
}
int padding_after_unusedA, padding_after_unusedB;
if (outputSize(context, input_rows, filter_rows_eff, 1, *stride_rows, isSameMode, out_rows, pad_top, &padding_after_unusedA) != Status::OK())
return Status::THROW("Dilation2D: bad height");
if (outputSize(context, input_cols, filter_cols_eff, 1, *stride_cols, isSameMode, out_cols, pad_left, &padding_after_unusedA) != Status::OK())
return Status::THROW("Dilation2D: bad width");
return Status::OK();
}
} }
} }
} }

View File

@ -30,7 +30,7 @@ namespace helpers {
T lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation); T lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation);
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output); int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
int log_abs_determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output); int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output); int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output);

View File

@ -26,11 +26,11 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock); void scatter(nd4j::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock);
void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock); void scatterND(nd4j::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock);
void scatterForLoss(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& updates, NDArray& output, const bool calcGrad); void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad);
} }
} }
} }

View File

@ -30,7 +30,7 @@ namespace helpers {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// svd operation, this function is not method of SVD class, it is standalone function // svd operation, this function is not method of SVD class, it is standalone function
void svd(nd4j::LaunchContext * context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum); void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum);
} }

View File

@ -1197,7 +1197,7 @@ inline __device__ float nd4j_atomicMul<float>(float* address, float val) {
do { do {
assumed = old; assumed = old;
old = atomicCAS(address_as_ull, assumed, __float_as_int(val * old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
__float_as_int(assumed))); __int_as_float(assumed)));
} while (assumed != old); } while (assumed != old);
return __int_as_float(old); return __int_as_float(old);
} }

View File

@ -2595,7 +2595,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_1) {
NDArray input('c', {N,bS,2*K}, nd4j::DataType::DOUBLE); NDArray input('c', {N,bS,2*K}, nd4j::DataType::DOUBLE);
NDArray weights('c', {2*K,6*K}, nd4j::DataType::DOUBLE); NDArray weights('c', {2*K,6*K}, nd4j::DataType::DOUBLE);
NDArray bias('c', {1,4*K}, nd4j::DataType::DOUBLE); NDArray bias('c', {4*K}, nd4j::DataType::DOUBLE);
NDArray init('c', {bS,2*K}, nd4j::DataType::DOUBLE); NDArray init('c', {bS,2*K}, nd4j::DataType::DOUBLE);
NDArray mask('c', {bS,2*K}, nd4j::DataType::DOUBLE); NDArray mask('c', {bS,2*K}, nd4j::DataType::DOUBLE);
NDArray expState('c', {N,bS,2*K}, {1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857}); NDArray expState('c', {N,bS,2*K}, {1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857});
@ -2635,7 +2635,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
auto input = NDArrayFactory::create<double>('c', {N,bS,2*K}); auto input = NDArrayFactory::create<double>('c', {N,bS,2*K});
auto weights = NDArrayFactory::create<double>('c', {2*K,6*K}); auto weights = NDArrayFactory::create<double>('c', {2*K,6*K});
auto bias = NDArrayFactory::create<double>('c', {1,4*K}); auto bias = NDArrayFactory::create<double>('c', {4*K});
auto init = NDArrayFactory::create<double>('c', {bS,2*K}); auto init = NDArrayFactory::create<double>('c', {bS,2*K});
auto mask = NDArrayFactory::create<double>('c', {bS,2*K}); auto mask = NDArrayFactory::create<double>('c', {bS,2*K});
NDArray state('c', {N,bS,2*K}, stateBuff); NDArray state('c', {N,bS,2*K}, stateBuff);
@ -2646,8 +2646,8 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
NDArray expGradX('c', {N,bS,2*K}, expGradXBuff); NDArray expGradX('c', {N,bS,2*K}, expGradXBuff);
NDArray expGradW('c', {N,2*K,6*K}, expGradWBuff); NDArray expGradW('c', {N,2*K,6*K}, expGradWBuff);
auto expGradB = NDArrayFactory::create<double>('c', {1,4*K}); auto expGradB = NDArrayFactory::create<double>('c', {4*K});
gradBias.reduceAlongDimension(reduce::Sum, &expGradB, {0}, false, true); // [bS x 4K] -> [1 x 4K] gradBias.reduceAlongDimension(reduce::Sum, &expGradB, {0}); // [bS, 4K] -> [4K]
NDArray expGradInit('c', {bS,2*K}, expGradInitBuff); NDArray expGradInit('c', {bS,2*K}, expGradInitBuff);
input.assign(1.5); input.assign(1.5);

View File

@ -391,30 +391,6 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) {
delete res; delete res;
} }
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, svd_test11) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1.,2.,3.,4.,5.,6.,7.,8.,9.});
auto expS = NDArrayFactory::create<double>('c', {3});
auto expU = NDArrayFactory::create<double>('c', {3,3});
auto expV = NDArrayFactory::create<double>('c', {3,3});
nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {0, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0);
auto u = results->at(1);
auto v = results->at(2);
ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v));
delete results;
}
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) { TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) {

View File

@ -2093,21 +2093,30 @@ TEST_F(DeclarableOpsTests3, svd_test1) {
auto *u = results->at(1); auto *u = results->at(1);
auto *v = results->at(2); auto *v = results->at(2);
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s));
if(nd4j::Environment::getInstance()->isCPU()) {
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
}
delete results; delete results;
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, svd_test2) { TEST_F(DeclarableOpsTests3, svd_test2) {
auto x= NDArrayFactory::create<float>('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); auto x = NDArrayFactory::create<float>('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2});
auto expS= NDArrayFactory::create<float>('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); auto expS= NDArrayFactory::create<float>('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672});
auto expU= NDArrayFactory::create<float>('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178}); auto expU= NDArrayFactory::create<float>('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178});
auto expV= NDArrayFactory::create<float>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); auto expV= NDArrayFactory::create<float>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
@ -2121,14 +2130,23 @@ TEST_F(DeclarableOpsTests3, svd_test2) {
auto *u = results->at(1); auto *u = results->at(1);
auto *v = results->at(2); auto *v = results->at(2);
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s));
if(nd4j::Environment::getInstance()->isCPU()) {
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
}
delete results; delete results;
} }
@ -2149,14 +2167,23 @@ TEST_F(DeclarableOpsTests3, svd_test3) {
auto *u = results->at(1); auto *u = results->at(1);
auto *v = results->at(2); auto *v = results->at(2);
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s));
if(nd4j::Environment::getInstance()->isCPU()) {
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
}
delete results; delete results;
} }
@ -2177,14 +2204,23 @@ TEST_F(DeclarableOpsTests3, svd_test4) {
auto *u = results->at(1); auto *u = results->at(1);
auto *v = results->at(2); auto *v = results->at(2);
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s));
if(nd4j::Environment::getInstance()->isCPU()) {
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
}
delete results; delete results;
} }
@ -2205,14 +2241,23 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
auto *u = results->at(1); auto *u = results->at(1);
auto *v = results->at(2); auto *v = results->at(2);
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s));
if(nd4j::Environment::getInstance()->isCPU()) {
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
}
delete results; delete results;
} }
@ -2220,56 +2265,27 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
TEST_F(DeclarableOpsTests3, svd_test6) { TEST_F(DeclarableOpsTests3, svd_test6) {
auto x= NDArrayFactory::create<float>('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 auto x= NDArrayFactory::create<float>('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2
,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17
,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17 ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14
,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5});
,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14
,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16
,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5});
auto expS= NDArrayFactory::create<float>('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, auto expS= NDArrayFactory::create<float>('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,
38.18412, 31.52287, 23.52755, 11.79484, 1.90195, 38.18412, 31.52287, 23.52755, 11.79484, 1.90195,
39.34498, 32.54861, 17.52492, 7.03003, 2.2399, 39.34498, 32.54861, 17.52492, 7.03003, 2.2399,
44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); 44.72126, 32.3164 , 16.60139, 6.88783, 0.78122});
auto expU= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054, auto expU= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054,
-0.32285, -0.58332, 0.3451 , 0.4746 , -0.45953, -0.32285, -0.58332, 0.3451 , 0.4746 , -0.45953,0.58332, 0.10605, 0.51533, 0.50234, 0.36136,0.12588, -0.73123, -0.37812, -0.00215, 0.55361,
0.58332, 0.10605, 0.51533, 0.50234, 0.36136, 0.68915, -0.2919 , 0.04767, -0.4197 , -0.51132,0.44464, -0.25326, -0.42493, -0.01712, -0.74653,0.516 , -0.16688, 0.1854 , -0.77155, 0.27611,
0.12588, -0.73123, -0.37812, -0.00215, 0.55361, -0.19321, -0.14317, -0.85886, -0.15224, 0.42585,-0.60155, -0.68323, 0.18819, -0.29053, -0.22696,-0.36993, 0.64862, -0.10956, -0.54483, -0.36552,
0.68915, -0.2919 , 0.04767, -0.4197 , -0.51132, -0.57697, -0.32277, 0.11229, 0.55495, 0.4923 ,-0.02937, 0.01689, -0.63257, 0.57075, -0.52245,-0.56002, -0.2036 , -0.53119, -0.6022 , 0.01017,
0.44464, -0.25326, -0.42493, -0.01712, -0.74653, -0.33605, -0.35257, 0.53215, -0.04936, -0.69075,0.48958, -0.85427, -0.14796, -0.03449, 0.08633,0.15008, 0.60996, 0.31071, -0.67721, 0.22421,
0.516 , -0.16688, 0.1854 , -0.77155, 0.27611, 0.67717, -0.59857, 0.04372, -0.2565 , 0.33979,0.68116, 0.49852, -0.13441, 0.51374, -0.07421,-0.20066, 0.04504, 0.42865, 0.44418, 0.75939,0.12113, -0.13826, 0.83651, 0.11988, -0.50209});
-0.19321, -0.14317, -0.85886, -0.15224, 0.42585,
-0.60155, -0.68323, 0.18819, -0.29053, -0.22696,
-0.36993, 0.64862, -0.10956, -0.54483, -0.36552,
-0.57697, -0.32277, 0.11229, 0.55495, 0.4923 ,
-0.02937, 0.01689, -0.63257, 0.57075, -0.52245,
-0.56002, -0.2036 , -0.53119, -0.6022 , 0.01017,
-0.33605, -0.35257, 0.53215, -0.04936, -0.69075,
0.48958, -0.85427, -0.14796, -0.03449, 0.08633,
0.15008, 0.60996, 0.31071, -0.67721, 0.22421,
0.67717, -0.59857, 0.04372, -0.2565 , 0.33979,
0.68116, 0.49852, -0.13441, 0.51374, -0.07421,
-0.20066, 0.04504, 0.42865, 0.44418, 0.75939,
0.12113, -0.13826, 0.83651, 0.11988, -0.50209});
auto expV= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781, auto expV= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781,
0.59651, -0.13439, -0.395 , 0.66979, 0.14654, 0.59651, -0.13439, -0.395 , 0.66979, 0.14654,0.73731, 0.47061, 0.19357, -0.41127, -0.16817,0.1047 , -0.29727, 0.73711, 0.38235, -0.45951,
0.73731, 0.47061, 0.19357, -0.41127, -0.16817, -0.29873, 0.80012, -0.02078, 0.4651 , -0.23201,-0.05314, -0.0419 , -0.52146, 0.77792, 0.344 ,-0.66438, 0.05648, 0.03756, -0.31531, 0.67422,
0.1047 , -0.29727, 0.73711, 0.38235, -0.45951, 0.74471, 0.01504, -0.03081, -0.24335, 0.62049,0.03172, 0.91947, 0.30828, 0.23713, 0.04796,-0.01311, 0.38652, -0.79415, -0.42423, -0.19945,
-0.29873, 0.80012, -0.02078, 0.4651 , -0.23201, -0.13783, -0.54667, -0.58527, 0.49955, 0.3001 ,0.85214, 0.01628, 0.02688, -0.02891, 0.52157,0.16608, -0.20181, 0.61371, 0.69894, -0.25794,
-0.05314, -0.0419 , -0.52146, 0.77792, 0.344 , 0.45726, -0.33952, -0.32659, -0.18938, -0.73015,0.13486, 0.73816, -0.41646, 0.47458, -0.1956 ,0.5536 , -0.137 , 0.64688, 0.50536, 0.03017,
-0.66438, 0.05648, 0.03756, -0.31531, 0.67422, -0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428});
0.74471, 0.01504, -0.03081, -0.24335, 0.62049,
0.03172, 0.91947, 0.30828, 0.23713, 0.04796,
-0.01311, 0.38652, -0.79415, -0.42423, -0.19945,
-0.13783, -0.54667, -0.58527, 0.49955, 0.3001 ,
0.85214, 0.01628, 0.02688, -0.02891, 0.52157,
0.16608, -0.20181, 0.61371, 0.69894, -0.25794,
0.45726, -0.33952, -0.32659, -0.18938, -0.73015,
0.13486, 0.73816, -0.41646, 0.47458, -0.1956 ,
0.5536 , -0.137 , 0.64688, 0.50536, 0.03017,
-0.51827, -0.31837, -0.16732, 0.71378, -0.30425,
-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,
-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,
-0.01282, 0.92491, -0.08042, 0.36977, -0.03428});
nd4j::ops::svd op; nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {1, 1, 16}); auto results = op.execute({&x}, {}, {1, 1, 16});
@ -2280,14 +2296,23 @@ TEST_F(DeclarableOpsTests3, svd_test6) {
auto *u = results->at(1); auto *u = results->at(1);
auto *v = results->at(2); auto *v = results->at(2);
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s));
if(nd4j::Environment::getInstance()->isCPU()) {
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
}
delete results; delete results;
} }
@ -2451,13 +2476,22 @@ TEST_F(DeclarableOpsTests3, svd_test7) {
// auto *u = results->at(1); // auto *u = results->at(1);
// auto *v = results->at(2); // auto *v = results->at(2);
// ASSERT_TRUE(expS.isSameShape(s)); // ASSERT_TRUE(expS.isSameShape(s));
// ASSERT_TRUE(expU.isSameShape(u)); // ASSERT_TRUE(expU.isSameShape(u));
// ASSERT_TRUE(expV.isSameShape(v)); // ASSERT_TRUE(expV.isSameShape(v));
// ASSERT_TRUE(expS.equalsTo(s)); // ASSERT_TRUE(expS.equalsTo(s));
// ASSERT_TRUE(expU.equalsTo(u));
// ASSERT_TRUE(expV.equalsTo(v)); // if(nd4j::Environment::getInstance()->isCPU()) {
// ASSERT_TRUE(expU.equalsTo(u));
// ASSERT_TRUE(expV.equalsTo(v));
// }
// else {
// for(uint i = 0; i < expU.lengthOf(); ++i)
// ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
// for(uint i = 0; i < expV.lengthOf(); ++i)
// ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
// }
// delete results; // delete results;
// } // }
@ -2555,14 +2589,23 @@ TEST_F(DeclarableOpsTests3, svd_test9) {
auto *u = results->at(1); auto *u = results->at(1);
auto *v = results->at(2); auto *v = results->at(2);
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v)); ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s));
if(nd4j::Environment::getInstance()->isCPU()) {
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
}
delete results; delete results;
} }
@ -2659,9 +2702,42 @@ TEST_F(DeclarableOpsTests3, svd_test10) {
auto *u = results->at(1); auto *u = results->at(1);
auto *v = results->at(2); auto *v = results->at(2);
ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u));
ASSERT_TRUE(expV.isSameShape(v));
ASSERT_TRUE(expS.equalsTo(s)); ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v)); if(nd4j::Environment::getInstance()->isCPU()) {
ASSERT_TRUE(expU.equalsTo(u));
ASSERT_TRUE(expV.equalsTo(v));
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
}
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, svd_test11) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1.,2.,3.,4.,5.,6.,7.,8.,9.});
auto expS = NDArrayFactory::create<double>('c', {3});
auto expU = NDArrayFactory::create<double>('c', {3,3});
auto expV = NDArrayFactory::create<double>('c', {3,3});
nd4j::ops::svd op;
auto results = op.execute({&x}, {}, {0, 1, 16});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0);
auto u = results->at(1);
auto v = results->at(2);
ASSERT_TRUE(expS.isSameShape(s)); ASSERT_TRUE(expS.isSameShape(s));
ASSERT_TRUE(expU.isSameShape(u)); ASSERT_TRUE(expU.isSameShape(u));
@ -2679,4 +2755,3 @@ TEST_F(DeclarableOpsTests3, svd_test10) {

View File

@ -1933,7 +1933,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test1) {
auto gradO = NDArrayFactory::create<double>('c', {2, 3, 2}); auto gradO = NDArrayFactory::create<double>('c', {2, 3, 2});
gradO = 0.5; gradO = 0.5;
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {0.,0.5,0.,0. ,0.,0. ,0.,0.5,0.,0. ,0.,0.}); auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {0.,0.5,0.,0. ,0.,0. ,0.,0.5,0.,0. ,0.,0.});
nd4j::ops::triu_bp op; nd4j::ops::triu_bp op;
auto results = op.execute({&input, &gradO}, {}, {1}); auto results = op.execute({&input, &gradO}, {}, {1});

View File

@ -177,9 +177,9 @@ TEST_F(DeclarableOpsTests5, Test_TTS_bp_1) {
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
z->printShapeInfo("RES shape"); // z->printShapeInfo("RES shape");
x.printShapeInfo("EXP shape"); // x.printShapeInfo("EXP shape");
z->printIndexedBuffer("RES output"); // z->printIndexedBuffer("RES output");
ASSERT_TRUE(x.isSameShape(z)); ASSERT_TRUE(x.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -1335,11 +1335,11 @@ TEST_F(DeclarableOpsTests5, trace_test1) {
auto results = op.execute({&input}, {}, {}); auto results = op.execute({&input}, {}, {});
auto output = results->at(0); auto output = results->at(0);
double traceM = matrix.getTrace(); double traceM = matrix.getTrace();
nd4j_printf("Trace for matrix is %f\n", traceM); // nd4j_printf("Trace for matrix is %f\n", traceM);
ASSERT_EQ(Status::OK(), results->status()); ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
exp.printIndexedBuffer("EXP TRACE"); // exp.printIndexedBuffer("EXP TRACE");
output->printIndexedBuffer("OUT TRACE"); // output->printIndexedBuffer("OUT TRACE");
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
delete results; delete results;

View File

@ -1209,6 +1209,51 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
delete res; delete res;
} }
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) {
auto x = NDArrayFactory::create<int>('c', {1}, {1});
auto y = NDArrayFactory::create<int>('c', {1}, {4});
// ------------------------------------
auto exp = NDArrayFactory::create<int>('c', {1}, {4});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output SGO 8");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) {
auto x = NDArrayFactory::create<int>('c', {2}, {2,2});
auto y = NDArrayFactory::create<int>('c', {1}, {1});
// ------------------------------------
auto exp = NDArrayFactory::create<int>('c', {2}, {2,2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
res->at(0)->printIndexedBuffer("Output SGO 9");
exp.printIndexedBuffer("Expect9");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) { TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
@ -1496,10 +1541,33 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
delete result; delete result;
} }
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogDet_2) {
auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98});
auto exp = NDArrayFactory::create<double>('c', {1}, { 3.5835189});
//x.printIndexedBuffer("Input");
nd4j::ops::logdet op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Output ");
// z->printShapeInfo("Shape");
//exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_1) { TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
auto x = NDArrayFactory::create<double>('c', {2, 5, 5}, { auto x = NDArrayFactory::create<float>('c', {2, 5, 5}, {
2., 4., 60., 8., 10., 2., 4., 60., 8., 10.,
0., 1., 2., 3., 4., 0., 1., 2., 3., 4.,
0., 0., 2., 4., 6., 0., 0., 2., 4., 6.,
@ -1513,7 +1581,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
5., 4., 3., 2., 1., 5., 4., 3., 2., 1.,
}); });
auto exp = NDArrayFactory::create<double>('c', {2, 5, 5}, { auto exp = NDArrayFactory::create<float>('c', {2, 5, 5}, {
0.5, -2.0, -13.0, 54.0, -6.75, 0.5, -2.0, -13.0, 54.0, -6.75,
0.0, 1.0, -1.0, 1.0, 0.0, 0.0, 1.0, -1.0, 1.0, 0.0,
0, 0, 0.5, -2.0, 0.25, 0, 0, 0.5, -2.0, 0.25,
@ -1528,7 +1596,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());

View File

@ -575,7 +575,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_1) {
z->printShapeInfo("Stitch Shape"); z->printShapeInfo("Stitch Shape");
ASSERT_TRUE(z->isSameShape(exp)); ASSERT_TRUE(z->isSameShape(exp));
ASSERT_TRUE(z->equalsTo(exp)); ASSERT_TRUE(z->equalsTo(exp));
delete result; delete result;
} }
@ -6242,23 +6242,18 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_3) {
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, Test_CumSum_BP_1) { TEST_F(DeclarableOpsTests7, cumsum_bp_1) {
auto x = NDArrayFactory::create<double>('c', {3, 4}); auto x = NDArrayFactory::create<double>('c', {3, 4});
// auto y = NDArrayFactory::create<double>('c', {3, 4});
// auto z; // = NDArrayFactory::create<double>('c', {4});
auto eps = NDArrayFactory::create<double>('c', {3, 4}); auto eps = NDArrayFactory::create<double>('c', {3, 4});
auto exp = NDArrayFactory::create<double>('c', {3, 4}, {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, auto exp = NDArrayFactory::create<double>('c', {3, 4}, {12.f, 11.f, 10.f, 9.f, 8.f, 7.f,
6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); 6.f, 5.f, 4.f, 3.f, 2.f, 1.f});
x.linspace(1); x.linspace(1);
eps.assign(1.f); eps.assign(1.f);
// z = x.applyReduce3<simdOps::Dot<float>>(&y, {0}, nullptr);
nd4j::ops::cumsum_bp op; nd4j::ops::cumsum_bp op;
auto result = op.execute({&x, &eps}, {}, {0,0}); auto result = op.execute({&x, &eps}, {}, {0,0});
auto output = result->at(0); auto output = result->at(0);
// output->printIndexedBuffer("Result is");
// output->printShapeInfo("Result shape is");
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -6266,27 +6261,21 @@ TEST_F(DeclarableOpsTests7, Test_CumSum_BP_1) {
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
delete result; delete result;
// delete z;
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, Test_CumSum_BP_2) { TEST_F(DeclarableOpsTests7, cumsum_bp_2) {
auto x = NDArrayFactory::create<double>('c', {3, 4}); auto x = NDArrayFactory::create<double>('c', {3, 4});
// auto y = NDArrayFactory::create<double>('c', {3, 4});
// auto z; // = NDArrayFactory::create<double>('c', {4});
auto eps = NDArrayFactory::create<double>('c', {3, 4}); auto eps = NDArrayFactory::create<double>('c', {3, 4});
auto exp = NDArrayFactory::create<double>('c', {3, 4}, { 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, auto exp = NDArrayFactory::create<double>('c', {3, 4}, { 11.f, 10.f, 9.f, 8.f, 7.f, 6.f,
5.f, 4.f, 3.f, 2.f, 1.f, 0.f}); 5.f, 4.f, 3.f, 2.f, 1.f, 0.f});
x.linspace(1); x.linspace(1);
// exp.linspace(1);
eps.assign(1.f); eps.assign(1.f);
// z = x.applyReduce3<simdOps::Dot<float>>(&y, {0}, nullptr);
nd4j::ops::cumsum_bp op; nd4j::ops::cumsum_bp op;
auto result = op.execute({&x, &eps}, {}, {1,0}); auto result = op.execute({&x, &eps}, {}, {1,0});
auto output = result->at(0); auto output = result->at(0);
// output->printIndexedBuffer("Result is");
// output->printShapeInfo("Result shape is");
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -6294,14 +6283,11 @@ TEST_F(DeclarableOpsTests7, Test_CumSum_BP_2) {
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
delete result; delete result;
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, Test_Reduce_CumSum_BP_3) { TEST_F(DeclarableOpsTests7, cumsum_bp_3) {
auto x = NDArrayFactory::create<double>('c', {3, 4}); auto x = NDArrayFactory::create<double>('c', {3, 4});
// auto y = NDArrayFactory::create<double>('c', {3, 4});
// auto z; // = NDArrayFactory::create<double>('c', {4});
auto eps = NDArrayFactory::create<double>('c', {3, 4}); auto eps = NDArrayFactory::create<double>('c', {3, 4});
auto exp = NDArrayFactory::create<double>('c', {3, 4}); auto exp = NDArrayFactory::create<double>('c', {3, 4});
@ -6309,16 +6295,11 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_CumSum_BP_3) {
exp.linspace(0); exp.linspace(0);
eps.assign(1.f); eps.assign(1.f);
// z = x.applyReduce3<simdOps::Dot<float>>(&y, {0}, nullptr);
nd4j::ops::cumsum_bp op; nd4j::ops::cumsum_bp op;
auto result = op.execute({&x, &eps}, {}, {1,1}); auto result = op.execute({&x, &eps}, {}, {1,1});
auto output = result->at(0); auto output = result->at(0);
// output->printIndexedBuffer("Result is");
// output->printShapeInfo("Result shape is");
ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_EQ(ND4J_STATUS_OK, result->status());
// ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
delete result; delete result;

View File

@ -236,45 +236,6 @@ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) {
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
} }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1.,2.,3.,4.,5.,6.});
auto gradO = NDArrayFactory::create<double>('c', {4, 9});
auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44});
gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {2, 3});
auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test2) {
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1.,2.,3.,4.,5.,6.});
auto gradO = NDArrayFactory::create<double>('c', {2, 9});
auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45});
gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {1, 3});
auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI));
delete results;
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test1) { TEST_F(DeclarableOpsTests9, concat_test1) {
@ -626,6 +587,45 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
delete result; delete result;
} }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1.,2.,3.,4.,5.,6.});
auto gradO = NDArrayFactory::create<double>('c', {4, 9});
auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44});
gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {2, 3});
auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test2) {
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1.,2.,3.,4.,5.,6.});
auto gradO = NDArrayFactory::create<double>('c', {2, 9});
auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45});
gradO.linspace(0.01, 0.01);
nd4j::ops::tile_bp op;
auto results = op.execute({&input, &gradO}, {}, {1, 3});
auto gradI = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(gradIExp.isSameShape(gradI));
ASSERT_TRUE(gradIExp.equalsTo(gradI));
delete results;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test3) { TEST_F(DeclarableOpsTests9, tile_bp_test3) {
@ -2623,21 +2623,52 @@ TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) {
auto dLdzX = NDArrayFactory::create<double>('c', {2, 4}); auto dLdzX = NDArrayFactory::create<double>('c', {2, 4});
auto dLdzY = NDArrayFactory::create<double>('c', {2, 4}); auto dLdzY = NDArrayFactory::create<double>('c', {2, 4});
auto dLdzZ = NDArrayFactory::create<double>('c', {2, 4}); auto dLdzZ = NDArrayFactory::create<double>('c', {2, 4});
auto exp = NDArrayFactory::create<double>('c', {2,3,4}, {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3});
x.linspace(1); x.linspace(1);
dLdzX.linspace(1); // dLdzX.linspace(1);
dLdzY.linspace(2); // dLdzY.linspace(2);
dLdzZ.linspace(3); // dLdzZ.linspace(3);
dLdzX.assign(1);
dLdzY.assign(2);
dLdzZ.assign(3);
nd4j::ops::dynamic_partition op1; nd4j::ops::dynamic_partition op1;
auto res1 = op1.execute({&x, &y}, {}, {3}); auto res1 = op1.execute({&x, &y}, {}, {3});
nd4j::ops::dynamic_partition_bp op2; nd4j::ops::dynamic_partition_bp op2;
auto res2 = op2.execute({&x, &y, res1->at(0), res1->at(1), res1->at(2)}, {}, {3}); auto res2 = op2.execute({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});
ASSERT_TRUE(res2->status() == ND4J_STATUS_OK); ASSERT_TRUE(res2->status() == ND4J_STATUS_OK);
ASSERT_TRUE(res2->size() == 2); ASSERT_TRUE(res2->size() == 2);
// printf("How many: %ul\n", res2->size());
// res2->at(0)->printBuffer("Ouputput0");
// res2->at(1)->printBuffer("Ouputput1");
ASSERT_TRUE(res2->at(0)->equalsTo(exp));
delete res1; delete res1;
delete res2; delete res2;
} }
//////////////////////////////////////////////////////////////////////
//TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_2) {
//
// auto x = NDArrayFactory::create<double>('c', {2, 3, 4});
// auto y = NDArrayFactory::create<int>('c', {2, 3}, {0, 1, 2, 1, 0, 2});
// auto dLdzX = NDArrayFactory::create<double>('c', {2, 4});
// auto dLdzY = NDArrayFactory::create<double>('c', {2, 4});
// auto dLdzZ = NDArrayFactory::create<double>('c', {2, 4});
// x.linspace(1);
// dLdzX.linspace(1);
// dLdzY.linspace(1);
// dLdzZ.linspace(1);
//
// const OpArgsHolder argsHolderFF({&x, &y}, {}, {3});
// const OpArgsHolder argsHolderBP({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});
//
// nd4j::ops::dynamic_partition opFF;
// nd4j::ops::dynamic_partition_bp opBP;
//
// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
//
// ASSERT_TRUE(isGradCorrect);
//}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
@ -2914,7 +2945,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
auto result = op.execute({&x}, {}, {}); auto result = op.execute({&x}, {}, {});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto res = result->at(0); auto res = result->at(0);
//res->printIndexedBuffer("Output for Cholesky"); // res->printIndexedBuffer("Output for Cholesky1");
ASSERT_TRUE(exp.equalsTo(res)); ASSERT_TRUE(exp.equalsTo(res));
delete result; delete result;
} }
@ -2922,6 +2953,22 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, Cholesky_Test_2) { TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
NDArray x = NDArrayFactory::create<double>('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6});
NDArray exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.});
nd4j::ops::cholesky op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto res = result->at(0);
// res->printIndexedBuffer("Output for Cholesky 2");
ASSERT_TRUE(exp.equalsTo(res));
delete result;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
NDArray x = NDArrayFactory::create<float>('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6}); NDArray x = NDArrayFactory::create<float>('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6});
NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.}); NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.});
@ -2930,7 +2977,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
auto result = op.execute({&x}, {}, {}); auto result = op.execute({&x}, {}, {});
ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->status(), ND4J_STATUS_OK);
auto res = result->at(0); auto res = result->at(0);
//res->printIndexedBuffer("Output for Cholesky"); // res->printIndexedBuffer("Output for Cholesky 3");
ASSERT_TRUE(exp.equalsTo(res)); ASSERT_TRUE(exp.equalsTo(res));
delete result; delete result;
} }

View File

@ -1215,10 +1215,14 @@ public class Nd4j {
protected static Indexer getIndexerByType(Pointer pointer, DataType dataType) { protected static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
switch (dataType) { switch (dataType) {
case UINT64:
case LONG: case LONG:
return LongIndexer.create((LongPointer) pointer); return LongIndexer.create((LongPointer) pointer);
case UINT32:
case INT: case INT:
return IntIndexer.create((IntPointer) pointer); return IntIndexer.create((IntPointer) pointer);
case UINT16:
return UShortIndexer.create((ShortPointer) pointer);
case SHORT: case SHORT:
return ShortIndexer.create((ShortPointer) pointer); return ShortIndexer.create((ShortPointer) pointer);
case BYTE: case BYTE:
@ -1229,6 +1233,8 @@ public class Nd4j {
return BooleanIndexer.create((BooleanPointer) pointer); return BooleanIndexer.create((BooleanPointer) pointer);
case FLOAT: case FLOAT:
return FloatIndexer.create((FloatPointer) pointer); return FloatIndexer.create((FloatPointer) pointer);
case BFLOAT16:
return Bfloat16Indexer.create((ShortPointer) pointer);
case HALF: case HALF:
return HalfIndexer.create((ShortPointer) pointer); return HalfIndexer.create((ShortPointer) pointer);
case DOUBLE: case DOUBLE:
@ -1297,12 +1303,15 @@ public class Nd4j {
public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) { public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) {
Pointer nPointer = null; Pointer nPointer = null;
switch (dataType) { switch (dataType) {
case UINT64:
case LONG: case LONG:
nPointer = new LongPointer(pointer); nPointer = new LongPointer(pointer);
break; break;
case UINT32:
case INT: case INT:
nPointer = new IntPointer(pointer); nPointer = new IntPointer(pointer);
break; break;
case UINT16:
case SHORT: case SHORT:
nPointer = new ShortPointer(pointer); nPointer = new ShortPointer(pointer);
break; break;
@ -1315,12 +1324,13 @@ public class Nd4j {
case BOOL: case BOOL:
nPointer = new BooleanPointer(pointer); nPointer = new BooleanPointer(pointer);
break; break;
case FLOAT: case BFLOAT16:
nPointer = new FloatPointer(pointer);
break;
case HALF: case HALF:
nPointer = new ShortPointer(pointer); nPointer = new ShortPointer(pointer);
break; break;
case FLOAT:
nPointer = new FloatPointer(pointer);
break;
case DOUBLE: case DOUBLE:
nPointer = new DoublePointer(pointer); nPointer = new DoublePointer(pointer);
break; break;

View File

@ -93,18 +93,7 @@ public class BasicContextPool implements ContextPool {
try { try {
// this is lockable thing, but since it locks once per thread initialization, performance impact won't be big // this is lockable thing, but since it locks once per thread initialization, performance impact won't be big
lock.acquire(); lock.acquire();
// we create 1 CUcontext per device, which will be shared for all threads/streams on this device
/*
if (!cuPool.containsKey(deviceId)) {
CUcontext cuContext = createNewContext(deviceId);
cuPool.put(deviceId, cuContext);
}
int result = JCudaDriver.cuCtxSetCurrent(cuPool.get(deviceId));
if (result != CUresult.CUDA_SUCCESS) {
throw new RuntimeException("Failed to set context on assigner");
}
*/
if (!contextsForDevices.containsKey(deviceId)) { if (!contextsForDevices.containsKey(deviceId)) {
contextsForDevices.put(deviceId, new ConcurrentHashMap<Integer, CudaContext>()); contextsForDevices.put(deviceId, new ConcurrentHashMap<Integer, CudaContext>());
} }
@ -120,11 +109,11 @@ public class BasicContextPool implements ContextPool {
// if we have no contexts created - it's just awesome time to attach cuBLAS handle here // if we have no contexts created - it's just awesome time to attach cuBLAS handle here
log.debug("Creating new cuBLAS handle for device [{}]...", deviceId); log.debug("Creating new cuBLAS handle for device [{}]...", deviceId);
cudaStream_t cublasStream = createNewStream(deviceId).getOldStream(); //cudaStream_t cublasStream = createNewStream(deviceId).getOldStream();
cublasHandle_t handle = createNewCublasHandle(cublasStream); cublasHandle_t handle = createNewCublasHandle(context.getOldStream());
context.setHandle(handle); context.setHandle(handle);
context.setCublasStream(cublasStream); //context.setCublasStream(cublasStream);
cublasPool.put(deviceId, handle); cublasPool.put(deviceId, handle);

View File

@ -62,11 +62,11 @@ public class PackedContextPool extends BasicContextPool implements ContextPool {
// if we have no contexts created - it's just awesome time to attach cuBLAS handle here // if we have no contexts created - it's just awesome time to attach cuBLAS handle here
log.debug("Creating new cuBLAS handle for device [{}]", deviceId); log.debug("Creating new cuBLAS handle for device [{}]", deviceId);
cudaStream_t cublasStream = createNewStream(deviceId).getOldStream(); //cudaStream_t cublasStream = createNewStream(deviceId).getOldStream();
cublasHandle_t handle = createNewCublasHandle(cublasStream); cublasHandle_t handle = createNewCublasHandle(context.getOldStream());
context.setHandle(handle); context.setHandle(handle);
context.setCublasStream(cublasStream); //context.setCublasStream(cublasStream);
cublasPool.put(deviceId, handle); cublasPool.put(deviceId, handle);

View File

@ -31,6 +31,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer; import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.jcublas.ops.executioner.CudaGridExecutioner; import org.nd4j.linalg.jcublas.ops.executioner.CudaGridExecutioner;
import org.nd4j.linalg.memory.BasicMemoryManager; import org.nd4j.linalg.memory.BasicMemoryManager;
@ -151,6 +152,14 @@ public class CudaMemoryManager extends BasicMemoryManager {
} }
protected void allocateHostPointers(DataBuffer... dataBuffers) {
for (val v:dataBuffers) {
if (v != null && v instanceof BaseCudaDataBuffer) {
((BaseCudaDataBuffer) v).lazyAllocateHostPointer();
}
}
}
/** /**
* This method provides basic memcpy functionality with respect to target environment * This method provides basic memcpy functionality with respect to target environment
* *
@ -161,9 +170,13 @@ public class CudaMemoryManager extends BasicMemoryManager {
public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) { public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
if (dstBuffer instanceof CompressedDataBuffer && !(srcBuffer instanceof CompressedDataBuffer)) { if (dstBuffer instanceof CompressedDataBuffer && !(srcBuffer instanceof CompressedDataBuffer)) {
// destination is compressed, source isn't // destination is compressed, source isn't
AllocationPoint srcPoint = AtomicAllocator.getInstance().getAllocationPoint(srcBuffer); AllocationPoint srcPoint = AtomicAllocator.getInstance().getAllocationPoint(srcBuffer);
allocateHostPointers(dstBuffer, srcBuffer);
long size = srcBuffer.getElementSize() * srcBuffer.length(); long size = srcBuffer.getElementSize() * srcBuffer.length();
if (!srcPoint.isActualOnHostSide()) { if (!srcPoint.isActualOnHostSide()) {
// copying device -> host // copying device -> host
@ -177,12 +190,14 @@ public class CudaMemoryManager extends BasicMemoryManager {
} // else { } // else {
// copying host -> host // copying host -> host
Pointer src = AtomicAllocator.getInstance().getHostPointer(srcBuffer); val src = AtomicAllocator.getInstance().getHostPointer(srcBuffer);
Pointer.memcpy(dstBuffer.addressPointer(), src, size); Pointer.memcpy(dstBuffer.addressPointer(), src, size);
// } // }
} else if (!(dstBuffer instanceof CompressedDataBuffer) && srcBuffer instanceof CompressedDataBuffer) { } else if (!(dstBuffer instanceof CompressedDataBuffer) && srcBuffer instanceof CompressedDataBuffer) {
allocateHostPointers(dstBuffer, srcBuffer);
// destination is NOT compressed, source is compressed // destination is NOT compressed, source is compressed
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(dstBuffer); AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(dstBuffer);
long size = srcBuffer.getElementSize() * srcBuffer.length(); long size = srcBuffer.getElementSize() * srcBuffer.length();
@ -193,6 +208,7 @@ public class CudaMemoryManager extends BasicMemoryManager {
} else if (dstBuffer instanceof CompressedDataBuffer && srcBuffer instanceof CompressedDataBuffer) { } else if (dstBuffer instanceof CompressedDataBuffer && srcBuffer instanceof CompressedDataBuffer) {
// both buffers are compressed, just fire memcpy // both buffers are compressed, just fire memcpy
allocateHostPointers(dstBuffer, srcBuffer);
Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(), Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(),
srcBuffer.length() * srcBuffer.getElementSize()); srcBuffer.length() * srcBuffer.getElementSize());

View File

@ -124,65 +124,6 @@ public class CublasPointer implements AutoCloseable {
this.cudaContext = context; this.cudaContext = context;
this.devicePointer = AtomicAllocator.getInstance().getPointer(array, context); this.devicePointer = AtomicAllocator.getInstance().getPointer(array, context);
/*
if(array instanceof IComplexNDArray) {
if(array.length() * 2 < array.data().length() && !array.isVector()) {
array = Shape.toOffsetZero(array);
}
}
buffer = (JCudaBuffer) array.data();
//the opName of this thread for knowing whether to copy data or not
//String opName = Thread.currentThread().getName();
this.arr = array;
if(array.elementWiseStride() < 0) {
this.arr = array.dup();
buffer = (JCudaBuffer) this.arr.data();
if(this.arr.elementWiseStride() < 0)
throw new IllegalStateException("Unable to iterate over buffer");
}
*/
//int compLength = arr instanceof IComplexNDArray ? arr.length() * 2 : arr.length();
////int stride = arr instanceof IComplexNDArray ? BlasBufferUtil.getBlasStride(arr) / 2 : BlasBufferUtil.getBlasStride(arr);
//no striding for upload if we are using the whole buffer
// System.out.println("Allocation offset: ["+array.offset()+"], length: ["+compLength+"], stride: ["+ stride+"]");
/*
buffer.getPointer(
this.arr,
stride
,this.arr.offset()
,compLength);
*/
/**
* Neat edge case here.
*
* The striding will overshoot the original array
* when the offset is zero (the case being when offset is zero
* sayon a getRow(0) operation.
*
* We need to allocate the data differently here
* due to how the striding works out.
*/
// Copy the data to the device iff the whole buffer hasn't been copied
/*
//Data is already copied into CUDA buffer during allocation at getPointer
if(!buffer.copied(opName)) {
ContextHolder.getInstance().getMemoryStrategy().setData(buffer,0,1,buffer.length());
//mark the buffer copied
buffer.setCopied(opName);
}*/
/*
DevicePointerInfo info = buffer.getPointersToContexts().get(Thread.currentThread().getName(), Triple.of(0, buffer.length(), 1));
hostPointer = info.getPointers().getHostPointer();
*/
} }

View File

@ -360,12 +360,16 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
for (INDArray m : matrices) { for (INDArray m : matrices) {
if (m.isEmpty())
continue;
CudaContext context = allocator.getFlowController().prepareAction(ret, m); CudaContext context = allocator.getFlowController().prepareAction(ret, m);
if (m.ordering() == order && ret.elementWiseStride() == m.elementWiseStride() if (m.ordering() == order && ret.elementWiseStride() == m.elementWiseStride()
&& ret.elementWiseStride() == 1) { && ret.elementWiseStride() == 1) {
// do memcpy in proper direction and forget about that // do memcpy in proper direction and forget about that
// FIXME: get rid of this
((BaseCudaDataBuffer) m.data()).lazyAllocateHostPointer();
allocator.memcpyAsync(ret.data(), new CudaPointer(allocator.getHostPointer(m).address()), allocator.memcpyAsync(ret.data(), new CudaPointer(allocator.getHostPointer(m).address()),
AllocationUtils.getRequiredMemory(AllocationUtils.buildAllocationShape(m)), AllocationUtils.getRequiredMemory(AllocationUtils.buildAllocationShape(m)),
linearIndex * (m.data().dataType() == DataType.DOUBLE ? 8 linearIndex * (m.data().dataType() == DataType.DOUBLE ? 8
@ -560,6 +564,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
for (int i = 0; i < toConcat.length; i++) { for (int i = 0; i < toConcat.length; i++) {
((BaseCudaDataBuffer) toConcat[i].data()).lazyAllocateHostPointer();
if (toConcat[i].isCompressed()) if (toConcat[i].isCompressed())
Nd4j.getCompressor().decompressi(toConcat[i]); Nd4j.getCompressor().decompressi(toConcat[i]);
@ -577,15 +583,15 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
outputShape[dimension] = sumAlongDim; outputShape[dimension] = sumAlongDim;
val dummy = new PointerPointer(new Pointer[] {null});
val ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order()); val ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order());
((BaseCudaDataBuffer) ret.data()).lazyAllocateHostPointer();
nativeOps.specialConcat(dummy, dimension, toConcat.length, dataPointers, shapeInfoPointers, nativeOps.specialConcat(null, dimension, toConcat.length, dataPointers, shapeInfoPointers,
ret.data().addressPointer(), ret.data().addressPointer(),
(LongPointer) ret.shapeInfoDataBuffer().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
new PointerPointer(new Pointer[] {null}), new PointerPointer(new Pointer[] {null})); null, null);
AllocationPoint point = allocator.getAllocationPoint(ret); AllocationPoint point = allocator.getAllocationPoint(ret);
@ -780,8 +786,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
allocator.getFlowController().registerAction(context, target, arrays); allocator.getFlowController().registerAction(context, target, arrays);
tempX.address();
return target; return target;
} else { } else {
long len = target.lengthLong(); long len = target.lengthLong();
@ -803,9 +807,13 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
if (arrays[i].lengthLong() != len) if (arrays[i].lengthLong() != len)
throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer();
dataPointers.put(i, AtomicAllocator.getInstance().getHostPointer(arrays[i])); dataPointers.put(i, AtomicAllocator.getInstance().getHostPointer(arrays[i]));
} }
if (target != null)
((BaseCudaDataBuffer) target.data()).lazyAllocateHostPointer();
nativeOps.accumulate(extras, nativeOps.accumulate(extras,
dataPointers, dataPointers,
@ -823,7 +831,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite(); AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite();
return target; return target;
} }
@ -893,8 +900,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
allocator.getFlowController().registerAction(context, target, arrays); allocator.getFlowController().registerAction(context, target, arrays);
tempX.address();
return target; return target;
} else { } else {
// otherwise we do averging on CPU side // otherwise we do averging on CPU side
@ -918,9 +923,14 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
if (arrays[i].lengthLong() != len) if (arrays[i].lengthLong() != len)
throw new ND4JIllegalStateException("All arrays should have equal length for averaging"); throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer();
dataPointers.put(i, AtomicAllocator.getInstance().getHostPointer(arrays[i])); dataPointers.put(i, AtomicAllocator.getInstance().getHostPointer(arrays[i]));
} }
if (target != null)
((BaseCudaDataBuffer) target.data()).lazyAllocateHostPointer();
nativeOps.average(extras, nativeOps.average(extras,
dataPointers, dataPointers,
(LongPointer) arrays[0].shapeInfoDataBuffer().addressPointer(), (LongPointer) arrays[0].shapeInfoDataBuffer().addressPointer(),
@ -1114,8 +1124,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
// just to keep reference // just to keep reference
shuffle.address(); //shuffle.address();
hostPointers.address(); //hostPointers.address();
tempX.dataType(); tempX.dataType();
tempShapes.dataType(); tempShapes.dataType();

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.jcublas.blas; package org.nd4j.linalg.jcublas.blas;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.DoublePointer; import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer; import org.bytedeco.javacpp.IntPointer;
@ -35,6 +36,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.jcublas.CublasPointer; import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
@ -81,7 +83,7 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
if (result != 0) if (result != 0)
throw new BlasException("solverSetStream failed"); throw new BlasException("solverSetStream failed");
@ -89,7 +91,8 @@ public class JcublasLapack extends BaseLapack {
CublasPointer xAPointer = new CublasPointer(a, ctx); CublasPointer xAPointer = new CublasPointer(a, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
int stat = cusolverDnSgetrf_bufferSize(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M, int stat = cusolverDnSgetrf_bufferSize(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M,
(IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
@ -147,15 +150,16 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
if (result != 0) if (result != 0)
throw new BlasException("solverSetStream failed"); throw new BlasException("solverSetStream failed");
// transfer the INDArray into GPU memory // transfer the INDArray into GPU memory
CublasPointer xAPointer = new CublasPointer(a, ctx); val xAPointer = new CublasPointer(a, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
int stat = cusolverDnDgetrf_bufferSize(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M, int stat = cusolverDnDgetrf_bufferSize(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
(IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
@ -167,7 +171,7 @@ public class JcublasLapack extends BaseLapack {
int worksize = worksizeBuffer.getInt(0); int worksize = worksizeBuffer.getInt(0);
// Now allocate memory for the workspace, the permutation matrix and a return code // Now allocate memory for the workspace, the permutation matrix and a return code
Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); val workspace = new Workspace(worksize * Nd4j.sizeOfDataType());
// Do the actual LU decomp // Do the actual LU decomp
stat = cusolverDnDgetrf(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M, stat = cusolverDnDgetrf(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
@ -218,7 +222,7 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
if (result != 0) if (result != 0)
throw new IllegalStateException("solverSetStream failed"); throw new IllegalStateException("solverSetStream failed");
@ -227,7 +231,8 @@ public class JcublasLapack extends BaseLapack {
CublasPointer xTauPointer = new CublasPointer(tau, ctx); CublasPointer xTauPointer = new CublasPointer(tau, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
int stat = cusolverDnSgeqrf_bufferSize(solverDn, M, N, int stat = cusolverDnSgeqrf_bufferSize(solverDn, M, N,
(FloatPointer) xAPointer.getDevicePointer(), M, (FloatPointer) xAPointer.getDevicePointer(), M,
@ -333,7 +338,7 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
if (result != 0) if (result != 0)
throw new BlasException("solverSetStream failed"); throw new BlasException("solverSetStream failed");
@ -342,7 +347,8 @@ public class JcublasLapack extends BaseLapack {
CublasPointer xTauPointer = new CublasPointer(tau, ctx); CublasPointer xTauPointer = new CublasPointer(tau, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
int stat = cusolverDnDgeqrf_bufferSize(solverDn, M, N, int stat = cusolverDnDgeqrf_bufferSize(solverDn, M, N,
(DoublePointer) xAPointer.getDevicePointer(), M, (DoublePointer) xAPointer.getDevicePointer(), M,
@ -441,7 +447,7 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
if (result != 0) if (result != 0)
throw new BlasException("solverSetStream failed"); throw new BlasException("solverSetStream failed");
@ -449,7 +455,8 @@ public class JcublasLapack extends BaseLapack {
CublasPointer xAPointer = new CublasPointer(a, ctx); CublasPointer xAPointer = new CublasPointer(a, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
int stat = cusolverDnSpotrf_bufferSize(solverDn, uplo, N, int stat = cusolverDnSpotrf_bufferSize(solverDn, uplo, N,
(FloatPointer) xAPointer.getDevicePointer(), N, (FloatPointer) xAPointer.getDevicePointer(), N,
@ -524,7 +531,7 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
int result = cusolverDnSetStream(solverDn, new CUstream_st(ctx.getOldStream())); int result = cusolverDnSetStream(solverDn, new CUstream_st(ctx.getCublasStream()));
if (result != 0) if (result != 0)
throw new BlasException("solverSetStream failed"); throw new BlasException("solverSetStream failed");
@ -532,7 +539,8 @@ public class JcublasLapack extends BaseLapack {
CublasPointer xAPointer = new CublasPointer(a, ctx); CublasPointer xAPointer = new CublasPointer(a, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
int stat = cusolverDnDpotrf_bufferSize(solverDn, uplo, N, int stat = cusolverDnDpotrf_bufferSize(solverDn, uplo, N,
(DoublePointer) xAPointer.getDevicePointer(), N, (DoublePointer) xAPointer.getDevicePointer(), N,
@ -656,7 +664,7 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
if (result != 0) if (result != 0)
throw new BlasException("solverSetStream failed"); throw new BlasException("solverSetStream failed");
@ -664,7 +672,8 @@ public class JcublasLapack extends BaseLapack {
CublasPointer xAPointer = new CublasPointer(a, ctx); CublasPointer xAPointer = new CublasPointer(a, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
int stat = cusolverDnSgesvd_bufferSize(solverDn, M, N, (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here int stat = cusolverDnSgesvd_bufferSize(solverDn, M, N, (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
); );
@ -765,7 +774,7 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
if (result != 0) if (result != 0)
throw new BlasException("solverSetStream failed"); throw new BlasException("solverSetStream failed");
@ -773,7 +782,8 @@ public class JcublasLapack extends BaseLapack {
CublasPointer xAPointer = new CublasPointer(a, ctx); CublasPointer xAPointer = new CublasPointer(a, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
int stat = cusolverDnSgesvd_bufferSize(solverDn, M, N, (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here int stat = cusolverDnSgesvd_bufferSize(solverDn, M, N, (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
); );
@ -851,14 +861,16 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
if (status == 0) { if (status == 0) {
// transfer the INDArray into GPU memory // transfer the INDArray into GPU memory
CublasPointer xAPointer = new CublasPointer(a, ctx); CublasPointer xAPointer = new CublasPointer(a, ctx);
CublasPointer xRPointer = new CublasPointer(R, ctx); CublasPointer xRPointer = new CublasPointer(R, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
status = cusolverDnSsyevd_bufferSize( status = cusolverDnSsyevd_bufferSize(
solverDn, jobz, uplo, M, solverDn, jobz, uplo, M,
(FloatPointer) xAPointer.getDevicePointer(), M, (FloatPointer) xAPointer.getDevicePointer(), M,
@ -869,7 +881,7 @@ public class JcublasLapack extends BaseLapack {
int worksize = worksizeBuffer.getInt(0); int worksize = worksizeBuffer.getInt(0);
// allocate memory for the workspace, the non-converging row buffer and a return code // allocate memory for the workspace, the non-converging row buffer and a return code
Pointer workspace = new Workspace(worksize * 4); //4 = float width val workspace = new Workspace(worksize * 4); //4 = float width
INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, A.dataType())); Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, A.dataType()));
@ -924,14 +936,16 @@ public class JcublasLapack extends BaseLapack {
// synchronized on the solver // synchronized on the solver
synchronized (handle) { synchronized (handle) {
status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream())); status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
if (status == 0) { if (status == 0) {
// transfer the INDArray into GPU memory // transfer the INDArray into GPU memory
CublasPointer xAPointer = new CublasPointer(a, ctx); CublasPointer xAPointer = new CublasPointer(a, ctx);
CublasPointer xRPointer = new CublasPointer(R, ctx); CublasPointer xRPointer = new CublasPointer(R, ctx);
// this output - indicates how much memory we'll need for the real operation // this output - indicates how much memory we'll need for the real operation
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1); val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
worksizeBuffer.lazyAllocateHostPointer();
status = cusolverDnDsyevd_bufferSize( status = cusolverDnDsyevd_bufferSize(
solverDn, jobz, uplo, M, solverDn, jobz, uplo, M,
(DoublePointer) xAPointer.getDevicePointer(), M, (DoublePointer) xAPointer.getDevicePointer(), M,

View File

@ -17,12 +17,11 @@
package org.nd4j.linalg.jcublas.blas; package org.nd4j.linalg.jcublas.blas;
import lombok.val; import lombok.val;
import org.bytedeco.javacpp.DoublePointer; import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.jita.allocator.Allocator; import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.linalg.api.blas.impl.BaseLevel1; import org.nd4j.linalg.api.blas.impl.BaseLevel1;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
@ -37,6 +36,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.CublasPointer; import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner; import org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jBlas; import org.nd4j.nativeblas.Nd4jBlas;
@ -82,30 +82,9 @@ public class JcublasLevel1 extends BaseLevel1 {
Nd4j.getExecutioner().exec(dot); Nd4j.getExecutioner().exec(dot);
ret = dot.getFinalResult().floatValue(); ret = dot.getFinalResult().floatValue();
/*
cublasHandle_t handle = ctx.getHandle();
synchronized (handle) {
long result = cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
if (result != 0)
throw new IllegalStateException("cublasSetStream failed");
FloatPointer resultPointer = new FloatPointer(0.0f);
cuBlasSdot_v2(new cublasContext(handle),
N,
xCPointer.getDevicePointer(),
incX,
yCPointer.getDevicePointer(),
incY, resultPointer);
ret = resultPointer.get();
}
*/
// allocator.registerAction(ctx, null, X, Y);
return ret; return ret;
} }
@Override @Override
protected float sdot(long N, INDArray X, int incX, INDArray Y, int incY) { protected float sdot(long N, INDArray X, int incX, INDArray Y, int incY) {
Preconditions.checkArgument(X.dataType() == DataType.FLOAT, "Float dtype expected"); Preconditions.checkArgument(X.dataType() == DataType.FLOAT, "Float dtype expected");
@ -114,22 +93,27 @@ public class JcublasLevel1 extends BaseLevel1 {
Nd4j.getExecutioner().push(); Nd4j.getExecutioner().push();
CudaContext ctx = allocator.getFlowController().prepareAction(null, X, Y); val ctx = allocator.getFlowController().prepareAction(null, X, Y);
float ret = 1f; float ret = 1f;
CublasPointer xCPointer = new CublasPointer(X, ctx); val xCPointer = new CublasPointer(X, ctx);
CublasPointer yCPointer = new CublasPointer(Y, ctx); val yCPointer = new CublasPointer(Y, ctx);
cublasHandle_t handle = ctx.getHandle(); val handle = ctx.getHandle();
val cctx = new cublasContext(handle);
synchronized (handle) { synchronized (handle) {
long result = cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); long result = cublasSetStream_v2(cctx, new CUstream_st(ctx.getCublasStream()));
if (result != 0) if (result != 0)
throw new IllegalStateException("cublasSetStream failed"); throw new IllegalStateException("cublasSetStream failed");
FloatPointer resultPointer = new FloatPointer(0.0f); val resultPointer = new FloatPointer(0.0f);
result = cublasSdot_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX, result = cublasSdot_v2(cctx, (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX, (FloatPointer) yCPointer.getDevicePointer(), incY, resultPointer);
(FloatPointer) yCPointer.getDevicePointer(), incY, resultPointer);
if (result != 0)
throw new IllegalStateException("cublasSdot_v2 failed. Error code: " + result);
ret = resultPointer.get(); ret = resultPointer.get();
} }
@ -155,17 +139,18 @@ public class JcublasLevel1 extends BaseLevel1 {
Nd4j.getExecutioner().push(); Nd4j.getExecutioner().push();
double ret; double ret;
CudaContext ctx = allocator.getFlowController().prepareAction(null, X, Y); val ctx = allocator.getFlowController().prepareAction(null, X, Y);
CublasPointer xCPointer = new CublasPointer(X, ctx); val xCPointer = new CublasPointer(X, ctx);
CublasPointer yCPointer = new CublasPointer(Y, ctx); val yCPointer = new CublasPointer(Y, ctx);
cublasHandle_t handle = ctx.getHandle(); val handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); val cctx = new cublasContext(handle);
cublasSetStream_v2(cctx, new CUstream_st(ctx.getCublasStream()));
DoublePointer resultPointer = new DoublePointer(0.0); val resultPointer = new DoublePointer(0.0);
cublasDdot_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX, cublasDdot_v2(cctx, (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
(DoublePointer) yCPointer.getDevicePointer(), incY, resultPointer); (DoublePointer) yCPointer.getDevicePointer(), incY, resultPointer);
ret = resultPointer.get(); ret = resultPointer.get();
} }
@ -194,7 +179,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
FloatPointer resultPointer = new FloatPointer(0.0f); FloatPointer resultPointer = new FloatPointer(0.0f);
cublasSnrm2_v2(new cublasContext(handle), (int) N, (FloatPointer) cAPointer.getDevicePointer(), incX, cublasSnrm2_v2(new cublasContext(handle), (int) N, (FloatPointer) cAPointer.getDevicePointer(), incX,
@ -226,28 +211,6 @@ public class JcublasLevel1 extends BaseLevel1 {
float ret = asum.getFinalResult().floatValue(); float ret = asum.getFinalResult().floatValue();
return ret; return ret;
/*
CudaContext ctx = allocator.getFlowController().prepareAction(null, X);
float ret;
CublasPointer xCPointer = new CublasPointer(X, ctx);
cublasHandle_t handle = ctx.getHandle();
synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
FloatPointer resultPointer = new FloatPointer(0.0f);
cublasSasum_v2(new cublasContext(handle),
N,
(FloatPointer) xCPointer.getDevicePointer(),
incX, resultPointer);
ret = resultPointer.get();
}
allocator.registerAction(ctx, null, X);
return ret;
*/
} }
@Override @Override
@ -274,7 +237,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
DoublePointer resultPointer = new DoublePointer(0.0f); DoublePointer resultPointer = new DoublePointer(0.0f);
cublasDnrm2_v2(new cublasContext(handle), (int) N, (DoublePointer) cAPointer.getDevicePointer(), incX, cublasDnrm2_v2(new cublasContext(handle), (int) N, (DoublePointer) cAPointer.getDevicePointer(), incX,
@ -295,26 +258,6 @@ public class JcublasLevel1 extends BaseLevel1 {
double ret = asum.getFinalResult().doubleValue(); double ret = asum.getFinalResult().doubleValue();
return ret; return ret;
/*CudaContext ctx = allocator.getFlowController().prepareAction(null, X);
double ret;
CublasPointer xCPointer = new CublasPointer(X, ctx);
cublasHandle_t handle = ctx.getHandle();
synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
DoublePointer resultPointer = new DoublePointer(0.0);
cublasDasum_v2(new cublasContext(handle),
N,
xCPointer.getDevicePointer(),
incX, resultPointer);
ret = resultPointer.get();
}
allocator.registerAction(ctx, null, X);
return ret;
*/
} }
@Override @Override
@ -335,7 +278,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
IntPointer resultPointer = new IntPointer(new int[] {0}); IntPointer resultPointer = new IntPointer(new int[] {0});
cublasIsamax_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX, cublasIsamax_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
@ -365,7 +308,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
IntPointer resultPointer = new IntPointer(new int[] {0}); IntPointer resultPointer = new IntPointer(new int[] {0});
cublasIdamax_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX, cublasIdamax_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
@ -396,7 +339,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasSswap_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX, cublasSswap_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
(FloatPointer) yCPointer.getDevicePointer(), incY); (FloatPointer) yCPointer.getDevicePointer(), incY);
@ -420,7 +363,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasScopy_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX, cublasScopy_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
(FloatPointer) yCPointer.getDevicePointer(), incY); (FloatPointer) yCPointer.getDevicePointer(), incY);
@ -443,28 +386,6 @@ public class JcublasLevel1 extends BaseLevel1 {
Nd4j.getExecutioner().exec(new Axpy(X, Y, Y, alpha)); Nd4j.getExecutioner().exec(new Axpy(X, Y, Y, alpha));
OpExecutionerUtil.checkForAny(Y); OpExecutionerUtil.checkForAny(Y);
/*
CublasPointer xAPointer = new CublasPointer(X, ctx);
CublasPointer xBPointer = new CublasPointer(Y, ctx);
cublasHandle_t handle = ctx.getHandle();
synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
PointerPointer p = new cublasContext(handle);
cublasSaxpy_v2(p,
N,
alpha,
xAPointer.getDevicePointer(),
incX,
xBPointer.getDevicePointer(),
incY);
}
*/
// allocator.registerAction(ctx, Y, X);
} }
@Override @Override
@ -479,21 +400,6 @@ public class JcublasLevel1 extends BaseLevel1 {
((CudaExecutioner) Nd4j.getExecutioner()).exec(new Axpy(X, Y, Y, alpha)); ((CudaExecutioner) Nd4j.getExecutioner()).exec(new Axpy(X, Y, Y, alpha));
OpExecutionerUtil.checkForAny(Y); OpExecutionerUtil.checkForAny(Y);
/* synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
PointerPointer p = new cublasContext(handle);
cublasSaxpy_v2(p,
N,
alpha,
xAPointer.getDevicePointer(),
incX,
xBPointer.getDevicePointer(),
incY);
}
*/
// allocator.registerAction(ctx, Y, X);
} }
@Override @Override
@ -520,7 +426,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDswap_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX, cublasDswap_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
(DoublePointer) yCPointer.getDevicePointer(), incY); (DoublePointer) yCPointer.getDevicePointer(), incY);
@ -542,7 +448,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDcopy_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX, cublasDcopy_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
(DoublePointer) yCPointer.getDevicePointer(), incY); (DoublePointer) yCPointer.getDevicePointer(), incY);
@ -568,22 +474,6 @@ public class JcublasLevel1 extends BaseLevel1 {
Nd4j.getExecutioner().exec(new Axpy(X, Y, Y, alpha)); Nd4j.getExecutioner().exec(new Axpy(X, Y, Y, alpha));
OpExecutionerUtil.checkForAny(Y); OpExecutionerUtil.checkForAny(Y);
/*
CublasPointer xAPointer = new CublasPointer(X, ctx);
CublasPointer xBPointer = new CublasPointer(Y, ctx);
cublasHandle_t handle = ctx.getHandle();
synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
cublasDaxpy_v2(new cublasContext(handle),
N, alpha, xAPointer.getDevicePointer(),
incX, xBPointer.getDevicePointer(),
incY);
}
*/
// allocator.registerAction(ctx, Y, X);
} }
@Override @Override
@ -652,7 +542,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasSscal_v2(new cublasContext(handle),(int) N, new FloatPointer(alpha), cublasSscal_v2(new cublasContext(handle),(int) N, new FloatPointer(alpha),
(FloatPointer) xCPointer.getDevicePointer(), incX); (FloatPointer) xCPointer.getDevicePointer(), incX);
@ -675,7 +565,7 @@ public class JcublasLevel1 extends BaseLevel1 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDscal_v2(new cublasContext(handle), (int) N, new DoublePointer(alpha), cublasDscal_v2(new cublasContext(handle), (int) N, new DoublePointer(alpha),
(DoublePointer) xCPointer.getDevicePointer(), incX); (DoublePointer) xCPointer.getDevicePointer(), incX);

View File

@ -64,7 +64,7 @@ public class JcublasLevel2 extends BaseLevel2 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasSgemv_v2(new cublasContext(handle), convertTranspose(TransA), M, N, new FloatPointer(alpha), cublasSgemv_v2(new cublasContext(handle), convertTranspose(TransA), M, N, new FloatPointer(alpha),
(FloatPointer) cAPointer.getDevicePointer(), lda, (FloatPointer) cAPointer.getDevicePointer(), lda,
@ -136,7 +136,7 @@ public class JcublasLevel2 extends BaseLevel2 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDgemv_v2(new cublasContext(handle), convertTranspose(TransA), M, N, new DoublePointer(alpha), cublasDgemv_v2(new cublasContext(handle), convertTranspose(TransA), M, N, new DoublePointer(alpha),
(DoublePointer) cAPointer.getDevicePointer(), lda, (DoublePointer) cAPointer.getDevicePointer(), lda,

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.jcublas.blas; package org.nd4j.linalg.jcublas.blas;
import lombok.val;
import org.bytedeco.javacpp.DoublePointer; import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer; import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.ShortPointer; import org.bytedeco.javacpp.ShortPointer;
@ -73,7 +74,7 @@ public class JcublasLevel3 extends BaseLevel3 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture(); int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
@ -111,15 +112,15 @@ public class JcublasLevel3 extends BaseLevel3 {
Nd4j.getExecutioner().push(); Nd4j.getExecutioner().push();
CudaContext ctx = allocator.getFlowController().prepareAction(C, A, B); val ctx = allocator.getFlowController().prepareAction(C, A, B);
CublasPointer cAPointer = new CublasPointer(A, ctx); val cAPointer = new CublasPointer(A, ctx);
CublasPointer cBPointer = new CublasPointer(B, ctx); val cBPointer = new CublasPointer(B, ctx);
CublasPointer cCPointer = new CublasPointer(C, ctx); val cCPointer = new CublasPointer(C, ctx);
cublasHandle_t handle = ctx.getHandle(); val handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K, cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda, new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
@ -145,7 +146,7 @@ public class JcublasLevel3 extends BaseLevel3 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasSsymm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), M, N, cublasSsymm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), M, N,
new FloatPointer(alpha), (FloatPointer) aPointer.getDevicePointer(), lda, new FloatPointer(alpha), (FloatPointer) aPointer.getDevicePointer(), lda,
@ -170,7 +171,7 @@ public class JcublasLevel3 extends BaseLevel3 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasSsyrk_v2(new cublasContext(handle), convertUplo(Uplo), convertTranspose(Trans), N, K, cublasSsyrk_v2(new cublasContext(handle), convertUplo(Uplo), convertTranspose(Trans), N, K,
new FloatPointer(alpha), (FloatPointer) aPointer.getDevicePointer(), lda, new FloatPointer(alpha), (FloatPointer) aPointer.getDevicePointer(), lda,
@ -207,7 +208,7 @@ public class JcublasLevel3 extends BaseLevel3 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasStrsm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), cublasStrsm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo),
convertTranspose(TransA), convertDiag(Diag), M, N, new FloatPointer(alpha), convertTranspose(TransA), convertDiag(Diag), M, N, new FloatPointer(alpha),
@ -227,17 +228,17 @@ public class JcublasLevel3 extends BaseLevel3 {
Nd4j.getExecutioner().push(); Nd4j.getExecutioner().push();
CudaContext ctx = allocator.getFlowController().prepareAction(C, A, B); val ctx = allocator.getFlowController().prepareAction(C, A, B);
DataTypeValidation.assertDouble(A, B, C); DataTypeValidation.assertDouble(A, B, C);
CublasPointer cAPointer = new CublasPointer(A, ctx); val cAPointer = new CublasPointer(A, ctx);
CublasPointer cBPointer = new CublasPointer(B, ctx); val cBPointer = new CublasPointer(B, ctx);
CublasPointer cCPointer = new CublasPointer(C, ctx); val cCPointer = new CublasPointer(C, ctx);
cublasHandle_t handle = ctx.getHandle(); val handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K, cublasDgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda, new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
@ -262,7 +263,7 @@ public class JcublasLevel3 extends BaseLevel3 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDsymm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), M, N, cublasDsymm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), M, N,
new DoublePointer(alpha), (DoublePointer) aPointer.getDevicePointer(), lda, new DoublePointer(alpha), (DoublePointer) aPointer.getDevicePointer(), lda,
@ -287,7 +288,7 @@ public class JcublasLevel3 extends BaseLevel3 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDsyrk_v2(new cublasContext(handle), convertUplo(Uplo), Trans, N, K, new DoublePointer(alpha), cublasDsyrk_v2(new cublasContext(handle), convertUplo(Uplo), Trans, N, K, new DoublePointer(alpha),
(DoublePointer) aPointer.getDevicePointer(), lda, new DoublePointer(beta), (DoublePointer) aPointer.getDevicePointer(), lda, new DoublePointer(beta),
@ -312,7 +313,7 @@ public class JcublasLevel3 extends BaseLevel3 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDsyr2k_v2(new cublasContext(handle), convertUplo(Uplo), Trans, N, K, new DoublePointer(alpha), cublasDsyr2k_v2(new cublasContext(handle), convertUplo(Uplo), Trans, N, K, new DoublePointer(alpha),
(DoublePointer) aPointer.getDevicePointer(), lda, (DoublePointer) aPointer.getDevicePointer(), lda,
@ -337,7 +338,7 @@ public class JcublasLevel3 extends BaseLevel3 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDtrmm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), cublasDtrmm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo),
convertTranspose(TransA), convertDiag(Diag), M, N, new DoublePointer(alpha), convertTranspose(TransA), convertDiag(Diag), M, N, new DoublePointer(alpha),
@ -363,7 +364,7 @@ public class JcublasLevel3 extends BaseLevel3 {
cublasHandle_t handle = ctx.getHandle(); cublasHandle_t handle = ctx.getHandle();
synchronized (handle) { synchronized (handle) {
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream())); cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
cublasDtrsm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), cublasDtrsm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo),
convertTranspose(TransA), convertDiag(Diag), M, N, new DoublePointer(alpha), convertTranspose(TransA), convertDiag(Diag), M, N, new DoublePointer(alpha),

View File

@ -241,7 +241,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
initPointers(length, Nd4j.sizeOfDataType(dtype), initialize); initPointers(length, Nd4j.sizeOfDataType(dtype), initialize);
} }
protected void lazyAllocateHostPointer() { public void lazyAllocateHostPointer() {
if (allocationPoint.getPointers().getHostPointer() == null) if (allocationPoint.getPointers().getHostPointer() == null)
initHostPointerAndIndexer(); initHostPointerAndIndexer();
} }

View File

@ -42,6 +42,10 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer {
super(pointer, indexer, length); super(pointer, indexer, length);
} }
public CudaBfloat16DataBuffer(Pointer pointer, Pointer specialPointer, Indexer indexer, long length){
super(pointer, specialPointer, indexer, length);
}
/** /**
* Base constructor * Base constructor
* *

View File

@ -42,6 +42,10 @@ public class CudaUInt16DataBuffer extends BaseCudaDataBuffer {
super(pointer, indexer, length); super(pointer, indexer, length);
} }
public CudaUInt16DataBuffer(Pointer pointer, Pointer specialPointer, Indexer indexer, long length){
super(pointer, specialPointer, indexer, length);
}
/** /**
* Base constructor * Base constructor
* *

View File

@ -42,6 +42,10 @@ public class CudaUInt32DataBuffer extends BaseCudaDataBuffer {
super(pointer, indexer, length); super(pointer, indexer, length);
} }
public CudaUInt32DataBuffer(Pointer pointer, Pointer specialPointer, Indexer indexer, long length){
super(pointer, specialPointer, indexer, length);
}
/** /**
* Base constructor * Base constructor
* *

View File

@ -100,6 +100,10 @@ public class CudaUInt64DataBuffer extends BaseCudaDataBuffer {
super(data, copy, offset, workspace); super(data, copy, offset, workspace);
} }
public CudaUInt64DataBuffer(Pointer pointer, Pointer specialPointer, Indexer indexer, long length){
super(pointer, specialPointer, indexer, length);
}
public CudaUInt64DataBuffer(double[] data) { public CudaUInt64DataBuffer(double[] data) {
super(data); super(data);
} }

View File

@ -72,10 +72,18 @@ public class CudaDataBufferFactory implements DataBufferFactory {
return new CudaFloatDataBuffer(underlyingBuffer, length, offset); return new CudaFloatDataBuffer(underlyingBuffer, length, offset);
case HALF: case HALF:
return new CudaHalfDataBuffer(underlyingBuffer, length, offset); return new CudaHalfDataBuffer(underlyingBuffer, length, offset);
case BFLOAT16:
return new CudaBfloat16DataBuffer(underlyingBuffer, length, offset);
case UINT64:
return new CudaUInt64DataBuffer(underlyingBuffer, length, offset);
case LONG: case LONG:
return new CudaLongDataBuffer(underlyingBuffer, length, offset); return new CudaLongDataBuffer(underlyingBuffer, length, offset);
case UINT32:
return new CudaUInt32DataBuffer(underlyingBuffer, length, offset);
case INT: case INT:
return new CudaIntDataBuffer(underlyingBuffer, length, offset); return new CudaIntDataBuffer(underlyingBuffer, length, offset);
case UINT16:
return new CudaUInt16DataBuffer(underlyingBuffer, length, offset);
case SHORT: case SHORT:
return new CudaShortDataBuffer(underlyingBuffer, length, offset); return new CudaShortDataBuffer(underlyingBuffer, length, offset);
case UBYTE: case UBYTE:
@ -684,16 +692,32 @@ public class CudaDataBufferFactory implements DataBufferFactory {
@Override @Override
public DataBuffer create(Pointer pointer, DataType type, long length, Indexer indexer) { public DataBuffer create(Pointer pointer, DataType type, long length, Indexer indexer) {
switch (type) { switch (type) {
case UINT64:
return new CudaUInt64DataBuffer(pointer, indexer, length);
case LONG: case LONG:
return new CudaLongDataBuffer(pointer, indexer, length); return new CudaLongDataBuffer(pointer, indexer, length);
case UINT32:
return new CudaUInt32DataBuffer(pointer, indexer, length);
case INT: case INT:
return new CudaIntDataBuffer(pointer, indexer, length); return new CudaIntDataBuffer(pointer, indexer, length);
case UINT16:
return new CudaUInt16DataBuffer(pointer, indexer, length);
case SHORT:
return new CudaShortDataBuffer(pointer, indexer, length);
case UBYTE:
return new CudaUByteDataBuffer(pointer, indexer, length);
case BYTE:
return new CudaByteDataBuffer(pointer, indexer, length);
case DOUBLE: case DOUBLE:
return new CudaDoubleDataBuffer(pointer, indexer, length); return new CudaDoubleDataBuffer(pointer, indexer, length);
case FLOAT: case FLOAT:
return new CudaFloatDataBuffer(pointer, indexer, length); return new CudaFloatDataBuffer(pointer, indexer, length);
case HALF: case HALF:
return new CudaHalfDataBuffer(pointer, indexer, length); return new CudaHalfDataBuffer(pointer, indexer, length);
case BFLOAT16:
return new CudaBfloat16DataBuffer(pointer, indexer, length);
case BOOL:
return new CudaBoolDataBuffer(pointer, indexer, length);
} }
throw new IllegalArgumentException("Illegal dtype " + type); throw new IllegalArgumentException("Illegal dtype " + type);
@ -702,16 +726,32 @@ public class CudaDataBufferFactory implements DataBufferFactory {
@Override @Override
public DataBuffer create(Pointer pointer, Pointer specialPointer, DataType type, long length, Indexer indexer) { public DataBuffer create(Pointer pointer, Pointer specialPointer, DataType type, long length, Indexer indexer) {
switch (type) { switch (type) {
case UINT64:
return new CudaUInt64DataBuffer(pointer, specialPointer, indexer, length);
case LONG: case LONG:
return new CudaLongDataBuffer(pointer, specialPointer, indexer, length); return new CudaLongDataBuffer(pointer, specialPointer, indexer, length);
case UINT32:
return new CudaUInt32DataBuffer(pointer, specialPointer, indexer, length);
case INT: case INT:
return new CudaIntDataBuffer(pointer, specialPointer, indexer, length); return new CudaIntDataBuffer(pointer, specialPointer, indexer, length);
case UINT16:
return new CudaUInt16DataBuffer(pointer, specialPointer, indexer, length);
case SHORT:
return new CudaShortDataBuffer(pointer, specialPointer, indexer, length);
case UBYTE:
return new CudaUByteDataBuffer(pointer, specialPointer, indexer, length);
case BYTE:
return new CudaByteDataBuffer(pointer, specialPointer, indexer, length);
case DOUBLE: case DOUBLE:
return new CudaDoubleDataBuffer(pointer, specialPointer, indexer, length); return new CudaDoubleDataBuffer(pointer, specialPointer, indexer, length);
case FLOAT: case FLOAT:
return new CudaFloatDataBuffer(pointer, specialPointer, indexer, length); return new CudaFloatDataBuffer(pointer, specialPointer, indexer, length);
case HALF: case HALF:
return new CudaHalfDataBuffer(pointer, specialPointer, indexer, length); return new CudaHalfDataBuffer(pointer, specialPointer, indexer, length);
case BFLOAT16:
return new CudaBfloat16DataBuffer(pointer, specialPointer, indexer, length);
case BOOL:
return new CudaBoolDataBuffer(pointer, specialPointer, indexer, length);
} }
throw new IllegalArgumentException("Illegal dtype " + type); throw new IllegalArgumentException("Illegal dtype " + type);

View File

@ -17,9 +17,13 @@
package org.nd4j.linalg.jcublas.context; package org.nd4j.linalg.jcublas.context;
import lombok.Data; import lombok.Data;
import lombok.val;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.jita.allocator.garbage.GarbageResourceReference; import org.nd4j.jita.allocator.garbage.GarbageResourceReference;
import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t; import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t; import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
@ -46,7 +50,6 @@ public class CudaContext {
//private CUevent cUevent; //private CUevent cUevent;
private cudaStream_t oldStream; private cudaStream_t oldStream;
private cudaStream_t cublasStream;
private cudaStream_t solverStream; private cudaStream_t solverStream;
private cudaStream_t specialStream; private cudaStream_t specialStream;
@ -130,17 +133,11 @@ public class CudaContext {
// ContextHolder.getInstance().setContext(); // ContextHolder.getInstance().setContext();
if (nativeOps.streamSynchronize(oldStream) == 0) if (nativeOps.streamSynchronize(oldStream) == 0)
throw new ND4JIllegalStateException("CUDA stream synchronization failed"); throw new ND4JIllegalStateException("CUDA stream synchronization failed");
if (syncCuBlas)
syncCublasStream();
} }
public void syncCublasStream() { public Pointer getCublasStream() {
if (cublasStream != null) { val lptr = new PointerPointer(this.getOldStream());
if (nativeOps.streamSynchronize(cublasStream) == 0) return lptr.get(0);
throw new ND4JIllegalStateException("CUDA stream synchronization failed");
} else
throw new IllegalStateException("cuBLAS stream isnt set");
} }

View File

@ -1984,13 +1984,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickDeviceWrite(); AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickDeviceWrite();
AtomicAllocator.getInstance().getAllocationPoint(buffer).tickDeviceWrite(); AtomicAllocator.getInstance().getAllocationPoint(buffer).tickDeviceWrite();
// just to ensure it's not purged
extras.address();
tempX.address();
buffers.getClass();
return Nd4j.createArrayFromShapeBuffer(encodedBuffer, input.shapeInfoDataBuffer()); return Nd4j.createArrayFromShapeBuffer(encodedBuffer, input.shapeInfoDataBuffer());
} }
@ -2171,14 +2164,17 @@ public class CudaExecutioner extends DefaultOpExecutioner {
return Collections.emptyList(); return Collections.emptyList();
} }
val inputBuffers = new PointerPointer<>(op.inputArguments().length); val inputBuffers = new PointerPointer<>(op.inputArguments().length * 2);
val inputShapes = new PointerPointer<>(op.inputArguments().length); val inputShapes = new PointerPointer<>(op.inputArguments().length);
int cnt= 0; int cnt= 0;
for (val in: op.inputArguments()) { for (val in: op.inputArguments()) {
// NOT A TYPO: shape functions work on host side only // NOT A TYPO: shape functions work on host side only
if (!in.isEmpty()) if (!in.isEmpty()) {
inputBuffers.put(cnt, in.data().addressPointer()); inputBuffers.put(cnt, in.data().addressPointer());
inputBuffers.put(cnt + op.inputArguments().length, AtomicAllocator.getInstance().getPointer(in.data()));
}
inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer()); inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer());
} }

View File

@ -4394,9 +4394,9 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
* - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both) * - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
* - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both) * - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both)
* direction - in what direction to fill matrix. There are 3 possible directions: * direction - in what direction to fill matrix. There are 3 possible directions:
* 'u' - fill up, mathematically this corresponds to lower triangular matrix, parameter "lower" is not taken into account * 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected
* 'l' - fill down, mathematically this corresponds to upper triangular matrix, parameter "upper" is not taken into account * 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected
* 'b' - fill in both directions, both parameters "lower" and "upper" are taken into account * 'b' - fill in both directions, both "lower" and "upper" are taken into account
* rest of target elements are equal to this array elements * rest of target elements are equal to this array elements
* target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2) * target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2)
*/ */
@ -7857,9 +7857,18 @@ public static final int PREALLOC_SIZE = 33554432;
* @param indices the indices to iterate over * @param indices the indices to iterate over
* @return the double at the specified index * @return the double at the specified index
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices,int rank); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices,int rank); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices,int rank); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank); @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank); @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
@ -8027,12 +8036,13 @@ public static final int PREALLOC_SIZE = 33554432;
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array // calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
// rank is equal to size of shape // rank is equal to size of shape

View File

@ -4394,9 +4394,9 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
* - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both) * - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
* - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both) * - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both)
* direction - in what direction to fill matrix. There are 3 possible directions: * direction - in what direction to fill matrix. There are 3 possible directions:
* 'u' - fill up, mathematically this corresponds to lower triangular matrix, parameter "lower" is not taken into account * 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected
* 'l' - fill down, mathematically this corresponds to upper triangular matrix, parameter "upper" is not taken into account * 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected
* 'b' - fill in both directions, both parameters "lower" and "upper" are taken into account * 'b' - fill in both directions, both "lower" and "upper" are taken into account
* rest of target elements are equal to this array elements * rest of target elements are equal to this array elements
* target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2) * target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2)
*/ */
@ -7857,9 +7857,18 @@ public static final int PREALLOC_SIZE = 33554432;
* @param indices the indices to iterate over * @param indices the indices to iterate over
* @return the double at the specified index * @return the double at the specified index
*/ */
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices,int rank); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices,int rank); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices,int rank); @Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices);
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank); @Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank); @Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
@ -8027,12 +8036,13 @@ public static final int PREALLOC_SIZE = 33554432;
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array // calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
// dimsToExclude - should be sorted in increasing order // dimsToExclude - should be sorted in increasing order
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/); // memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo); @Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array // calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
// rank is equal to size of shape // rank is equal to size of shape

View File

@ -28,6 +28,7 @@ import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;
import org.nd4j.imports.TFGraphs.NodeReader; import org.nd4j.imports.TFGraphs.NodeReader;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.blas.params.GemmParams; import org.nd4j.linalg.api.blas.params.GemmParams;
import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
@ -118,6 +119,7 @@ import static org.junit.Assert.assertArrayEquals;
public class Nd4jTestsC extends BaseNd4jTest { public class Nd4jTestsC extends BaseNd4jTest {
DataType initialType; DataType initialType;
Level1 l1;
@Rule @Rule
public TemporaryFolder testDir = new TemporaryFolder(); public TemporaryFolder testDir = new TemporaryFolder();
@ -125,6 +127,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
public Nd4jTestsC(Nd4jBackend backend) { public Nd4jTestsC(Nd4jBackend backend) {
super(backend); super(backend);
this.initialType = Nd4j.dataType(); this.initialType = Nd4j.dataType();
l1 = Nd4j.getBlasWrapper().level1();
} }
@ -431,7 +434,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
} }
@Test @Test
public void testMmulOp() { public void testMmulOp() throws Exception {
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
INDArray z = Nd4j.create(2, 2); INDArray z = Nd4j.create(2, 2);
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}}); INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
@ -2797,15 +2800,21 @@ public class Nd4jTestsC extends BaseNd4jTest {
@Test @Test
public void testDot() { public void testDot() throws Exception {
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4}); INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
assertEquals(10.f, vec1.sumNumber().floatValue(), 1e-5);
assertEquals(10.f, vec2.sumNumber().floatValue(), 1e-5);
assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1); assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1);
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
INDArray row = matrix.getRow(1); INDArray row = matrix.getRow(1);
assertEquals(25, Nd4j.getBlasWrapper().dot(row, row), 1e-1);
assertEquals(7.0f, row.sumNumber().floatValue(), 1e-5f);
assertEquals(25, Nd4j.getBlasWrapper().dot(row, row), 1e-1);
} }
@ -2815,8 +2824,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape())); assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
eye = Nd4j.eye(5); eye = Nd4j.eye(5);
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape())); assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
} }
@Test @Test
@ -4224,6 +4231,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(10, matrix.rows()); assertEquals(10, matrix.rows());
assertEquals(6, matrix.columns()); assertEquals(6, matrix.columns());
log.info("Result: {}", matrix);
for (int x = 0; x < 10; x++) { for (int x = 0; x < 10; x++) {
assertEquals((double) x, matrix.getRow(x).meanNumber().doubleValue(), 0.1); assertEquals((double) x, matrix.getRow(x).meanNumber().doubleValue(), 0.1);
assertEquals(arrays.get(x), matrix.getRow(x).reshape(1, matrix.size(1))); assertEquals(arrays.get(x), matrix.getRow(x).reshape(1, matrix.size(1)));

View File

@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.nio.ShortBuffer; import java.nio.ShortBuffer;