[WIP] more CUDA stuff (#57)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * Added gradcheck test for dynamic_partition_bp op. * - implementation of dilation op (cpu and cuda) Signed-off-by: Yurii <yurii@skymind.io> * Fixed broadcast_dynamic_shape 1D case and tests. * Fixed usage of default integer arguments. * Fixed dynamic_partition_bp op and tests. * Eliminated test with grad check for dynamic_partition_bp op. * start working on cuda svd - porting available corresponding api from cuSOLVER library Signed-off-by: Yurii <yurii@skymind.io> * provide prelu_bp Signed-off-by: Yurii <yurii@skymind.io> * - provide gruCell_bp (old version ??) Signed-off-by: Yurii <yurii@skymind.io> * - polishing cumsum_bp and cumprod_bp tests Signed-off-by: Yurii <yurii@skymind.io> * provide sparseSoftmaxCrossEntropyWithLogits and sparseSoftmaxCrossEntropyWithLogits_grad Signed-off-by: Yurii <yurii@skymind.io> * Fixed atomicMul with float input/output * implementation of cuda kernel for triu_bp operation Signed-off-by: Yurii <yurii@skymind.io> * Refactored lup helper to add parrallel computing. * cusolver libraries Signed-off-by: raver119 <raver119@gmail.com> * uncomment cuSolver APIs in svd.cu Signed-off-by: Yurii <yurii@skymind.io> * cusolver var Signed-off-by: raver119 <raver119@gmail.com> * - further work on cuSolver svd Signed-off-by: Yurii <yurii@skymind.io> * Implement usage of cuda solver to LUP decomposition. * - correct naames in lup functions Signed-off-by: Yurii <yurii@skymind.io> * correct svdQR cuda Signed-off-by: Yurii <yurii@skymind.io> * - provide transpositions of input matrices in case of c order in svdCudaQR Signed-off-by: Yurii <yurii@skymind.io> * Fixed implementation issues with LUP usign cuda solver. * Implementation of matrix_determinant helper with cuda kernels. Working revision. * Implemented log_matrix_determinant helper with cuda kernels. * - implementation of batched cuda svd Signed-off-by: Yurii <yurii@skymind.io> * Refactored cholesky helper and implementation of cuda solver cholesky batch. * - implementation of cuda kernel for tile bp Signed-off-by: Yurii <yurii@skymind.io> * Implementation of cholesky and logdet with cuda kernels. * - implementation of cuda kernel for sru_bidirectional Signed-off-by: Yurii <yurii@skymind.io> * Fixed cholesky helper. * Cholesky op helper implementation. Working double-based cublas implementation. * bad import excluded Signed-off-by: raver119 <raver119@gmail.com> * Finished with cuda implementation of cholesky helper and tests. * - implementation of cuda kernel for sru_bidirectional_backprop operation Signed-off-by: Yurii <yurii@skymind.io> * Implementation of matrix_inverse op helper with cuda kernels. The first revision. * - start working on gruCell_bp Signed-off-by: Yurii <yurii@skymind.io> * Implementation of matrix_inverse helper. * - further work on new gruCell_bp Signed-off-by: Yurii <yurii@skymind.io> * cuBLAS related fixes Signed-off-by: raver119 <raver119@gmail.com> * calculateOutputShapes() now passes device buffers as well Signed-off-by: raver119 <raver119@gmail.com> * special concat/average/accumulate init host pointers now Signed-off-by: raver119 <raver119@gmail.com> * few more tweaks Signed-off-by: raver119 <raver119@gmail.com> * additional CudaDataBufferFactory signatures certain for data types Signed-off-by: raver119 <raver119@gmail.com> * cuSolver host buffer Signed-off-by: raver119 <raver119@gmail.com> * buffer to buffer memcpy host ptr allocation Signed-off-by: raver119 <raver119@gmail.com>master
parent
cb6654bebb
commit
c969b724bb
|
@ -288,7 +288,7 @@ if(CUDA_BLAS)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
target_link_libraries(${LIBND4J_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES})
|
target_link_libraries(${LIBND4J_NAME} ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY})
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}/cuda)
|
||||||
|
|
||||||
install(TARGETS ${LIBND4J_NAME} DESTINATION .)
|
install(TARGETS ${LIBND4J_NAME} DESTINATION .)
|
||||||
|
|
|
@ -1143,9 +1143,9 @@ namespace nd4j {
|
||||||
* - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
|
* - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
|
||||||
* - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both)
|
* - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both)
|
||||||
* direction - in what direction to fill matrix. There are 3 possible directions:
|
* direction - in what direction to fill matrix. There are 3 possible directions:
|
||||||
* 'u' - fill up, mathematically this corresponds to lower triangular matrix, parameter "lower" is not taken into account
|
* 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected
|
||||||
* 'l' - fill down, mathematically this corresponds to upper triangular matrix, parameter "upper" is not taken into account
|
* 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected
|
||||||
* 'b' - fill in both directions, both parameters "lower" and "upper" are taken into account
|
* 'b' - fill in both directions, both "lower" and "upper" are taken into account
|
||||||
* rest of target elements are equal to this array elements
|
* rest of target elements are equal to this array elements
|
||||||
* target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2)
|
* target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2)
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -2371,7 +2371,7 @@ void NDArray::tileToShape(const std::vector<Nd4jLong>& shape, NDArray* target) {
|
||||||
if(i > rankOf())
|
if(i > rankOf())
|
||||||
repeats[newRank-i] = shape[newRank - i];
|
repeats[newRank-i] = shape[newRank - i];
|
||||||
else
|
else
|
||||||
repeats[newRank-i] = shape[newRank - i] / thisShape[rankOf() - i];
|
repeats[newRank-i] = shape[newRank - i] / thisShape[rankOf() - i];
|
||||||
}
|
}
|
||||||
|
|
||||||
tilei(repeats);
|
tilei(repeats);
|
||||||
|
|
|
@ -228,7 +228,7 @@ void* NDArray::getSpecialBuffer() const {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// change an array by repeating it the number of times given by reps.
|
// change an array by repeating it the number of times given by reps.
|
||||||
NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
||||||
int dim = reps.size();
|
const int repsSize = reps.size();
|
||||||
int product = 1;
|
int product = 1;
|
||||||
for(const auto& item : reps)
|
for(const auto& item : reps)
|
||||||
product *= item;
|
product *= item;
|
||||||
|
@ -236,11 +236,11 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
|
||||||
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
|
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
|
||||||
|
|
||||||
int rankOld = rankOf();
|
int rankOld = rankOf();
|
||||||
int diff = rankOld - dim;
|
int diff = rankOld - repsSize;
|
||||||
if(product==1) { // in this case 2 possibilities are present: just reshape or nothing to do
|
if(product==1) { // in this case 2 possibilities are present: just reshape or nothing to do
|
||||||
NDArray result(*this);
|
NDArray result(*this);
|
||||||
if(diff < 0) { // reshape to higher dimension
|
if(diff < 0) { // reshape to higher dimension
|
||||||
std::vector<Nd4jLong> shapeNew = reps; // need to have unities at first "diff" positions of new shape
|
std::vector<Nd4jLong> shapeNew = reps; // there is requirement to have unities at first "diff" positions of new shape
|
||||||
memcpy(&shapeNew[-diff], result.getShapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions
|
memcpy(&shapeNew[-diff], result.getShapeInfo()+1, rankOld * sizeof(Nd4jLong)); // put old shape numbers at rest of positions
|
||||||
result.reshapei(ordering(), shapeNew);
|
result.reshapei(ordering(), shapeNew);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1489,14 +1489,8 @@ void NativeOps::specialConcat(
|
||||||
Nd4jPointer *inputShapeInfo,
|
Nd4jPointer *inputShapeInfo,
|
||||||
void *dZ,
|
void *dZ,
|
||||||
Nd4jLong *dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) {
|
Nd4jLong *dZShapeInfo, Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) {
|
||||||
nd4j::SpecialMethods<float>::concatCpuGeneric(
|
|
||||||
dimension,
|
|
||||||
numArrays,
|
|
||||||
data,
|
|
||||||
inputShapeInfo,
|
|
||||||
dZ,
|
|
||||||
dZShapeInfo);
|
|
||||||
|
|
||||||
|
BUILD_SINGLE_SELECTOR(ArrayOptions::dataType(dZShapeInfo), nd4j::SpecialMethods ,::concatCpuGeneric(dimension, numArrays, data, inputShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2578,8 +2572,9 @@ nd4j::ShapeList* _calculateOutputShapes(Nd4jPointer* extraPointers, nd4j::ops::D
|
||||||
|
|
||||||
// we shouldn't copy buffer if that's empty array
|
// we shouldn't copy buffer if that's empty array
|
||||||
void *buffer_ = nd4j::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e];
|
void *buffer_ = nd4j::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e];
|
||||||
|
void *bufferD_ = nd4j::ArrayOptions::arrayType(shape_) == ArrayType::EMPTY ? nullptr : inputBuffers[e + numInputShapes];
|
||||||
|
|
||||||
auto array = new nd4j::NDArray(buffer_, shape_);
|
auto array = new nd4j::NDArray(buffer_, bufferD_, shape_);
|
||||||
|
|
||||||
// block should contain references to proper variable
|
// block should contain references to proper variable
|
||||||
varSpace.putVariable(1, e, array);
|
varSpace.putVariable(1, e, array);
|
||||||
|
|
|
@ -86,6 +86,7 @@ class ND4J_EXPORT LaunchContext {
|
||||||
|
|
||||||
FORCEINLINE void setCudaStream(cudaStream_t* cudaStream) {_cudaStream = cudaStream;};
|
FORCEINLINE void setCudaStream(cudaStream_t* cudaStream) {_cudaStream = cudaStream;};
|
||||||
FORCEINLINE void setCudaSpecialStream(cudaStream_t* cudaStream) {_cudaSpecialStream = cudaStream;};
|
FORCEINLINE void setCudaSpecialStream(cudaStream_t* cudaStream) {_cudaSpecialStream = cudaStream;};
|
||||||
|
FORCEINLINE void setCublasHandle(void *handle) {_cublasHandle = handle; };
|
||||||
|
|
||||||
|
|
||||||
#endif // JCPP
|
#endif // JCPP
|
||||||
|
|
|
@ -454,6 +454,9 @@ namespace nd4j {
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
_context = new LaunchContext(cudaStream, reductionPointer, allocationPointer);
|
_context = new LaunchContext(cudaStream, reductionPointer, allocationPointer);
|
||||||
|
|
||||||
|
// FIXME: either pass handle from outside, or make sure outside we use the same handle
|
||||||
|
_context->setCublasHandle(LaunchContext::defaultContext()->getCublasHandle());
|
||||||
|
|
||||||
for (auto v: _fastpath_out)
|
for (auto v: _fastpath_out)
|
||||||
v->setContext(_context);
|
v->setContext(_context);
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,6 @@
|
||||||
#include <cublas_v2.h>
|
#include <cublas_v2.h>
|
||||||
#include "../MmulHelper.h"
|
#include "../MmulHelper.h"
|
||||||
#include <specials_cuda.h>
|
#include <specials_cuda.h>
|
||||||
#include <helpers/PointersManager.h>
|
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
|
||||||
|
|
|
@ -552,7 +552,7 @@ std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& max, const NDAr
|
||||||
// evaluate shapeInfo for resulting array from tile operation
|
// evaluate shapeInfo for resulting array from tile operation
|
||||||
Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, nd4j::memory::Workspace* workspace) {
|
Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, nd4j::memory::Workspace* workspace) {
|
||||||
// check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing)
|
// check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing)
|
||||||
int dim = reps.size();
|
int repsSize = reps.size();
|
||||||
int product = 1;
|
int product = 1;
|
||||||
for(const auto& item : reps)
|
for(const auto& item : reps)
|
||||||
product *= item;
|
product *= item;
|
||||||
|
@ -560,24 +560,24 @@ Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd
|
||||||
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
|
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
|
||||||
|
|
||||||
int rankOld = arr.rankOf();
|
int rankOld = arr.rankOf();
|
||||||
int diff = rankOld - dim;
|
int diff = rankOld - repsSize;
|
||||||
|
|
||||||
// evaluate new shapeInfo
|
// evaluate new shapeInfo
|
||||||
Nd4jLong* newShapeInfo = nullptr;
|
Nd4jLong* newShapeInfo = nullptr;
|
||||||
if(diff < 0) {
|
if(diff < 0) {
|
||||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(dim), Nd4jLong);
|
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong);
|
||||||
newShapeInfo[0] = dim; // set new rank
|
newShapeInfo[0] = repsSize; // set new rank
|
||||||
for(int i=1; i <= -diff; ++i)
|
for(int i=1; i <= -diff; ++i)
|
||||||
newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place
|
newShapeInfo[i] = 1; // set unities to be new dimensions at left-hand side of newShapeInfo shape place
|
||||||
memcpy(newShapeInfo + 1 - diff, arr.getShapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place
|
memcpy(newShapeInfo + 1 - diff, arr.getShapeInfo() + 1, rankOld*sizeof(Nd4jLong)); // copy old dimensions to the right-hand side of newShapeInfo shape place
|
||||||
for(int i=1; i <= dim; ++i)
|
for(int i=1; i <= repsSize; ++i)
|
||||||
newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps
|
newShapeInfo[i] *= reps[i - 1]; // set new shape by multiplying old dimensions by corresponding numbers from reps
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
|
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
|
||||||
memcpy(newShapeInfo, arr.getShapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo
|
memcpy(newShapeInfo, arr.getShapeInfo(), shape::shapeInfoByteLength(rankOld)); // copy all elements of _shapeInfo to newShapeInfo
|
||||||
for(int i=1; i <= dim; ++i)
|
for(int i=1; i <= repsSize; ++i)
|
||||||
newShapeInfo[rankOld + 1 - i] *= reps[dim - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps
|
newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i]; // set new shape by multiplying old dimensions by corresponding numbers from reps
|
||||||
}
|
}
|
||||||
shape::updateStrides(newShapeInfo, arr.ordering());
|
shape::updateStrides(newShapeInfo, arr.ordering());
|
||||||
ArrayOptions::setDataType(newShapeInfo, arr.dataType());
|
ArrayOptions::setDataType(newShapeInfo, arr.dataType());
|
||||||
|
|
|
@ -889,7 +889,9 @@ namespace shape {
|
||||||
* @param indices the indices to iterate over
|
* @param indices the indices to iterate over
|
||||||
* @return the double at the specified index
|
* @return the double at the specified index
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices,int rank);
|
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices, const int rank);
|
||||||
|
ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset = 0);
|
||||||
|
ND4J_EXPORT Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices);
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
|
ND4J_EXPORT _CUDA_HD Nd4jLong* createShapeInfo(Nd4jLong *shape, Nd4jLong *stride, int rank);
|
||||||
|
|
||||||
|
@ -987,7 +989,8 @@ namespace shape {
|
||||||
|
|
||||||
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude = nullptr);
|
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
|
||||||
|
ND4J_EXPORT _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude = nullptr);
|
||||||
|
|
||||||
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
||||||
// rank is equal to size of shape
|
// rank is equal to size of shape
|
||||||
|
@ -3200,7 +3203,7 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
|
||||||
* @param indices the indices to iterate over
|
* @param indices the indices to iterate over
|
||||||
* @return the double at the specified index
|
* @return the double at the specified index
|
||||||
*/
|
*/
|
||||||
INLINEDEF _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices, int rank) {
|
INLINEDEF _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices, const int rank) {
|
||||||
Nd4jLong offset = baseOffset;
|
Nd4jLong offset = baseOffset;
|
||||||
for(int i = 0; i < rank; i++) {
|
for(int i = 0; i < rank; i++) {
|
||||||
if(shape[i] != 1)
|
if(shape[i] != 1)
|
||||||
|
@ -3210,6 +3213,21 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
|
||||||
return offset;
|
return offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset) {
|
||||||
|
return shape::getOffset(baseOffset, shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)), shape::stride(const_cast<Nd4jLong*>(shapeInfo)), indices, shapeInfo[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
INLINEDEF Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices) {
|
||||||
|
|
||||||
|
Nd4jLong offset = 0;
|
||||||
|
|
||||||
|
for(uint i = 0; i < shapeInfo[0]; ++i)
|
||||||
|
if(shapeInfo[i + 1] != 1)
|
||||||
|
offset += indices[i] * shapeInfo[shapeInfo[0] + i + 1];
|
||||||
|
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -4226,21 +4244,18 @@ INLINEDEF _CUDA_HD void maxIndToMinInd(Nd4jLong* maxIdxs, Nd4jLong* minIdxs, con
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude) {
|
INLINEDEF _CUDA_HD int outerArrayOffsets(Nd4jLong* maxOffsets, const Nd4jLong minIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, Nd4jLong* memBuff, const int* dimsToExclude) {
|
||||||
|
|
||||||
const auto rankMin = shape::rank(minShapeInfo);
|
const auto rankMin = shape::rank(minShapeInfo);
|
||||||
const auto rankMax = shape::rank(maxShapeInfo);
|
const auto rankMax = shape::rank(maxShapeInfo);
|
||||||
|
|
||||||
// if(rankMin >= rankMax)
|
// if(rankMin >= rankMax)
|
||||||
// throw std::runtime_error("shape::subArrayIndex method: rank of min array should be smaller then rank of max array!");
|
// throw std::runtime_error("shape::subArrayIndex method: rank of min array should be smaller then rank of max array!");
|
||||||
// if(rankMax > MAX_RANK/2)
|
|
||||||
// throw std::runtime_error("shape::subArrayIndex method: rank of max array should be <= MAX_RANK/2 !");
|
|
||||||
|
|
||||||
const auto diff = rankMax - rankMin; // the size of dimsToExclude is equal to diff
|
const auto diff = rankMax - rankMin; // the size of dimsToExclude is equal to diff
|
||||||
|
|
||||||
Nd4jLong buffer[MAX_RANK];
|
Nd4jLong* indices = memBuff;
|
||||||
Nd4jLong* indices = buffer;
|
Nd4jLong* increment = memBuff + rankMax;
|
||||||
Nd4jLong* increment = buffer + MAX_RANK/2;
|
|
||||||
|
|
||||||
int N, minI, maxI;
|
int N, minI, maxI;
|
||||||
|
|
||||||
|
|
|
@ -92,7 +92,6 @@ __global__ void execOesTadKernelKey(void *vx, Nd4jLong *xShapeInfo,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
|
__global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
|
||||||
|
|
|
@ -39,7 +39,6 @@ CONFIGURABLE_OP_IMPL(prelu, 2, 1, true, 0, 0) {
|
||||||
std::vector<int> sharedAxes = *block.getIArguments();
|
std::vector<int> sharedAxes = *block.getIArguments();
|
||||||
|
|
||||||
const int inputRank = input->rankOf();
|
const int inputRank = input->rankOf();
|
||||||
const int alphaRank = alpha->rankOf();
|
|
||||||
const int numSharedAxes = sharedAxes.size(); // can be zero as well
|
const int numSharedAxes = sharedAxes.size(); // can be zero as well
|
||||||
const Nd4jLong inputLen = input->lengthOf();
|
const Nd4jLong inputLen = input->lengthOf();
|
||||||
const Nd4jLong alphaLen = alpha->lengthOf();
|
const Nd4jLong alphaLen = alpha->lengthOf();
|
||||||
|
@ -91,7 +90,6 @@ CONFIGURABLE_OP_IMPL(prelu_bp, 3, 2, true, 0, 0) {
|
||||||
std::vector<int> sharedAxes = *block.getIArguments();
|
std::vector<int> sharedAxes = *block.getIArguments();
|
||||||
|
|
||||||
const int inputRank = input->rankOf();
|
const int inputRank = input->rankOf();
|
||||||
const int alphaRank = alpha->rankOf();
|
|
||||||
const int numSharedAxes = sharedAxes.size(); // can be zero as well
|
const int numSharedAxes = sharedAxes.size(); // can be zero as well
|
||||||
const Nd4jLong inputLen = input->lengthOf();
|
const Nd4jLong inputLen = input->lengthOf();
|
||||||
const Nd4jLong alphaLen = alpha->lengthOf();
|
const Nd4jLong alphaLen = alpha->lengthOf();
|
||||||
|
|
|
@ -29,18 +29,18 @@ namespace ops {
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) {
|
CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
const int rank = x->rankOf();
|
const int rank = x->rankOf();
|
||||||
REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank);
|
REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank);
|
||||||
|
|
||||||
const bool fullUV = (bool)INT_ARG(0);
|
const bool fullUV = (bool)INT_ARG(0);
|
||||||
const bool calcUV = (bool)INT_ARG(1);
|
const bool calcUV = (bool)INT_ARG(1);
|
||||||
const int switchNum = INT_ARG(2);
|
const int switchNum = INT_ARG(2);
|
||||||
|
|
||||||
#ifndef __CUDABLAS__
|
// #ifndef __CUDABLAS__
|
||||||
helpers::svd(block.launchContext(), x, {OUTPUT_VARIABLE(0), calcUV ? OUTPUT_VARIABLE(1) : nullptr, calcUV ? OUTPUT_VARIABLE(2) : nullptr}, fullUV, calcUV, switchNum);
|
helpers::svd(block.launchContext(), x, {OUTPUT_VARIABLE(0), calcUV ? OUTPUT_VARIABLE(1) : nullptr, calcUV ? OUTPUT_VARIABLE(2) : nullptr}, fullUV, calcUV, switchNum);
|
||||||
#endif
|
// #endif
|
||||||
|
|
||||||
return Status::OK();;
|
return Status::OK();;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -56,28 +56,28 @@ DECLARE_SHAPE_FN(svd) {
|
||||||
auto inShapeInfo = inputShape->at(0);
|
auto inShapeInfo = inputShape->at(0);
|
||||||
bool fullUV = (bool)INT_ARG(0);
|
bool fullUV = (bool)INT_ARG(0);
|
||||||
bool calcUV = (bool)INT_ARG(1);
|
bool calcUV = (bool)INT_ARG(1);
|
||||||
|
|
||||||
const int rank = inShapeInfo[0];
|
const int rank = inShapeInfo[0];
|
||||||
REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank);
|
REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank);
|
||||||
|
|
||||||
const int diagSize = inShapeInfo[rank] < inShapeInfo[rank-1] ? inShapeInfo[rank] : inShapeInfo[rank-1];
|
const int diagSize = inShapeInfo[rank] < inShapeInfo[rank-1] ? inShapeInfo[rank] : inShapeInfo[rank-1];
|
||||||
|
|
||||||
Nd4jLong* sShapeInfo(nullptr);
|
Nd4jLong* sShapeInfo(nullptr);
|
||||||
if(rank == 2) {
|
if(rank == 2) {
|
||||||
ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong);
|
ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong);
|
||||||
sShapeInfo[0] = 1;
|
sShapeInfo[0] = 1;
|
||||||
sShapeInfo[1] = diagSize;
|
sShapeInfo[1] = diagSize;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank-1), Nd4jLong);
|
ALLOCATE(sShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank-1), Nd4jLong);
|
||||||
sShapeInfo[0] = rank - 1;
|
sShapeInfo[0] = rank - 1;
|
||||||
for(int i=1; i <= rank-2; ++i)
|
for(int i=1; i <= rank-2; ++i)
|
||||||
sShapeInfo[i] = inShapeInfo[i];
|
sShapeInfo[i] = inShapeInfo[i];
|
||||||
sShapeInfo[rank-1] = diagSize;
|
sShapeInfo[rank-1] = diagSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(sShapeInfo, inShapeInfo, shape::order(inShapeInfo));
|
ShapeUtils::updateStridesAndType(sShapeInfo, inShapeInfo, shape::order(inShapeInfo));
|
||||||
|
|
||||||
if(calcUV){
|
if(calcUV){
|
||||||
|
|
||||||
Nd4jLong *uShapeInfo(nullptr), *vShapeInfo(nullptr);
|
Nd4jLong *uShapeInfo(nullptr), *vShapeInfo(nullptr);
|
||||||
|
@ -93,10 +93,10 @@ DECLARE_SHAPE_FN(svd) {
|
||||||
vShapeInfo[rank-1] = vShapeInfo[rank];
|
vShapeInfo[rank-1] = vShapeInfo[rank];
|
||||||
vShapeInfo[rank] = diagSize;
|
vShapeInfo[rank] = diagSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
shape::updateStrides(uShapeInfo, shape::order(inShapeInfo));
|
shape::updateStrides(uShapeInfo, shape::order(inShapeInfo));
|
||||||
shape::updateStrides(vShapeInfo, shape::order(inShapeInfo));
|
shape::updateStrides(vShapeInfo, shape::order(inShapeInfo));
|
||||||
|
|
||||||
auto result = SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(sShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(uShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(vShapeInfo)));
|
auto result = SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(sShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(uShapeInfo)), ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(vShapeInfo)));
|
||||||
RELEASE(sShapeInfo, block.workspace());
|
RELEASE(sShapeInfo, block.workspace());
|
||||||
RELEASE(uShapeInfo, block.workspace());
|
RELEASE(uShapeInfo, block.workspace());
|
||||||
|
|
|
@ -35,8 +35,8 @@ namespace ops {
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "Dilation2D: input should be 4D");
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "Dilation2D: input should be 4D");
|
||||||
REQUIRE_TRUE(weights->rankOf() == 3, 0, "Dilation2D: weights should be 3D");
|
REQUIRE_TRUE(weights->rankOf() == 3, 0, "Dilation2D: weights should be 3D");
|
||||||
|
|
||||||
const int batch_size = input->sizeAt(0);
|
const int bS = input->sizeAt(0);
|
||||||
const int depth = input->sizeAt(3);
|
const int iC = input->sizeAt(3);
|
||||||
const bool isSameShape = INT_ARG(0) == 1;
|
const bool isSameShape = INT_ARG(0) == 1;
|
||||||
|
|
||||||
REQUIRE_TRUE(input->sizeAt(3) == weights->sizeAt(2), 0, "Dilation2D: number of input channels doesn't match number of channels in weights: %i vs %i", input->sizeAt(3), weights->sizeAt(2));
|
REQUIRE_TRUE(input->sizeAt(3) == weights->sizeAt(2), 0, "Dilation2D: number of input channels doesn't match number of channels in weights: %i vs %i", input->sizeAt(3), weights->sizeAt(2));
|
||||||
|
@ -66,17 +66,17 @@ namespace ops {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
int stride_rows = 0, stride_cols = 0;
|
int sH = 0, sW = 0;
|
||||||
int rate_rows = 0, rate_cols = 0;
|
int dH = 0, dW = 0;
|
||||||
int pad_top = 0, pad_left = 0;
|
int pH = 0, pW = 0;
|
||||||
int out_rows = 0, out_cols = 0;
|
int oH = 0, oW = 0;
|
||||||
|
|
||||||
helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
|
helpers::dilation_hw(block.launchContext(), input->shapeInfo(), weights->shapeInfo(), strides, rates, isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW);
|
||||||
|
|
||||||
|
|
||||||
REQUIRE_TRUE(out_rows > 0 && out_cols > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", out_rows, out_cols);
|
REQUIRE_TRUE(oH > 0 && oW > 0, 0, "Dilation2D: outY and outX should have positive values, but got [%i, %i] instead", oH, oW);
|
||||||
|
|
||||||
helpers::dilation2d(block.launchContext(), input, weights, output, stride_rows, stride_cols, rate_rows, rate_cols, pad_top, pad_left);
|
helpers::dilation2d(block.launchContext(), input, weights, output, sH, sW, pH, pW, dH, dW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -91,8 +91,8 @@ namespace ops {
|
||||||
auto input = inputShape->at(0);
|
auto input = inputShape->at(0);
|
||||||
auto weights = inputShape->at(1);
|
auto weights = inputShape->at(1);
|
||||||
|
|
||||||
const int batch_size = shape::sizeAt(input, 0);
|
const int bS = shape::sizeAt(input, 0);
|
||||||
const int depth = shape::sizeAt(input, 3);
|
const int iC = shape::sizeAt(input, 3);
|
||||||
const bool isSameShape = INT_ARG(0) == 1;
|
const bool isSameShape = INT_ARG(0) == 1;
|
||||||
|
|
||||||
std::vector<int> strides(4);
|
std::vector<int> strides(4);
|
||||||
|
@ -121,14 +121,14 @@ namespace ops {
|
||||||
strides[cnt] = INT_ARG(e++);
|
strides[cnt] = INT_ARG(e++);
|
||||||
}
|
}
|
||||||
|
|
||||||
int stride_rows = 0, stride_cols = 0;
|
int sH = 0, sW = 0;
|
||||||
int rate_rows = 0, rate_cols = 0;
|
int dH = 0, dW = 0;
|
||||||
int pad_top = 0, pad_left = 0;
|
int pH = 0, pW = 0;
|
||||||
int out_rows = 0, out_cols = 0;
|
int oH = 0, oW = 0;
|
||||||
|
|
||||||
helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &stride_rows, &stride_cols, &rate_rows, &rate_cols, &pad_top, &pad_left, &out_rows, &out_cols);
|
helpers::dilation_hw(block.launchContext(), input, weights, strides, rates, isSameShape, &sH, &sW, &pH, &pW, &dH, &dW, &oH, &oW);
|
||||||
|
|
||||||
std::array<Nd4jLong, 4> shape = {{batch_size, out_rows, out_cols, depth}};
|
std::array<Nd4jLong, 4> shape = {{bS, oH, oW, iC}};
|
||||||
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data());
|
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(weights), 'c', 4, shape.data());
|
||||||
return SHAPELIST(newShape);
|
return SHAPELIST(newShape);
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,32 +37,32 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0)
|
||||||
|
|
||||||
const int labelsRank = labels->rankOf();
|
const int labelsRank = labels->rankOf();
|
||||||
const int logitsRank = logits->rankOf();
|
const int logitsRank = logits->rankOf();
|
||||||
|
|
||||||
// input validation
|
// input validation
|
||||||
REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank);
|
REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank);
|
||||||
|
|
||||||
std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector(); // this is correct
|
std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector(); // this is correct
|
||||||
std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector();
|
std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector();
|
||||||
logitsShape.pop_back();
|
logitsShape.pop_back();
|
||||||
bool equalSoft = logitsShape == labelsShape;
|
bool equalSoft = logitsShape == labelsShape;
|
||||||
|
|
||||||
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str());
|
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str());
|
||||||
|
|
||||||
std::vector<int> dimension = {-1};
|
std::vector<int> dimension = {-1};
|
||||||
|
|
||||||
auto maxAlongDim = logits->reduceAlongDims(reduce::Max, dimension, true);
|
auto maxAlongDim = logits->reduceAlongDims(reduce::Max, dimension, true);
|
||||||
auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr);
|
auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr);
|
||||||
auto logSoftMax = ( logitsExp / logitsExp.reduceAlongDims(reduce::Sum, dimension, true) ).transform(transform::Log);
|
auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDims(reduce::Sum, dimension, true) ).transform(transform::Log));
|
||||||
|
|
||||||
helpers::scatterForLoss(block.launchContext(), *labels, -logSoftMax, *output, false);
|
helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, false);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits) {
|
DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits) {
|
||||||
|
|
||||||
getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS});
|
getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -79,8 +79,8 @@ DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits) {
|
||||||
if (labelsShapeInfo[i] != logitsShapeInfo[i]) {
|
if (labelsShapeInfo[i] != logitsShapeInfo[i]) {
|
||||||
equalSoft = false;
|
equalSoft = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
|
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
|
||||||
|
|
||||||
auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, logitsShapeInfo, false, block.getWorkspace());
|
auto outShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, logitsShapeInfo, false, block.getWorkspace());
|
||||||
|
@ -96,7 +96,7 @@ DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
auto labels = INPUT_VARIABLE(0);
|
auto labels = INPUT_VARIABLE(0);
|
||||||
auto logits = INPUT_VARIABLE(1);
|
auto logits = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
@ -104,15 +104,15 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false,
|
||||||
|
|
||||||
const int labelsRank = labels->rankOf();
|
const int labelsRank = labels->rankOf();
|
||||||
const int logitsRank = logits->rankOf();
|
const int logitsRank = logits->rankOf();
|
||||||
|
|
||||||
// input validation
|
// input validation
|
||||||
REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank);
|
REQUIRE_TRUE(labelsRank == logitsRank - 1, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: input arrays should satisfy relation (labels_rank = logits_rank - 1), but got labels_rank = %i and logits_rank = %i instead !", labelsRank, logitsRank);
|
||||||
|
|
||||||
std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector(); // this is correct
|
std::vector<Nd4jLong> labelsShape = labels->getShapeAsVector(); // this is correct
|
||||||
std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector();
|
std::vector<Nd4jLong> logitsShape = logits->getShapeAsVector();
|
||||||
logitsShape.pop_back();
|
logitsShape.pop_back();
|
||||||
bool equalSoft = logitsShape == labelsShape;
|
bool equalSoft = logitsShape == labelsShape;
|
||||||
|
|
||||||
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str());
|
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShape).c_str(), ShapeUtils::shapeAsString(logitsShape).c_str());
|
||||||
|
|
||||||
std::vector<int> dimension = {-1};
|
std::vector<int> dimension = {-1};
|
||||||
|
@ -123,7 +123,7 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false,
|
||||||
// dEdp = softmax - 1 (or 0)
|
// dEdp = softmax - 1 (or 0)
|
||||||
dLdp->assign(softmax);
|
dLdp->assign(softmax);
|
||||||
|
|
||||||
// subtract unities at appropriate indexes of dLdp array
|
// subtract unities at appropriate indexes of dLdp array
|
||||||
helpers::scatterForLoss(block.launchContext(), *labels, *dLdp, *labels /*actually third array is unnecessary for gradient calculation*/, true);
|
helpers::scatterForLoss(block.launchContext(), *labels, *dLdp, *labels /*actually third array is unnecessary for gradient calculation*/, true);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -131,8 +131,8 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false,
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits_grad) {
|
DECLARE_TYPES(sparse_softmax_cross_entropy_loss_with_logits_grad) {
|
||||||
|
|
||||||
getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS});
|
getOpDescriptor()->setAllowedInputTypes(0, {ALL_INTS})->setAllowedInputTypes(1, {ALL_FLOATS})->setAllowedOutputTypes({ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -149,14 +149,14 @@ DECLARE_SHAPE_FN(sparse_softmax_cross_entropy_loss_with_logits_grad) {
|
||||||
if (labelsShapeInfo[i] != logitsShapeInfo[i]) {
|
if (labelsShapeInfo[i] != logitsShapeInfo[i]) {
|
||||||
equalSoft = false;
|
equalSoft = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
|
REQUIRE_TRUE(equalSoft, 0, "SPARSE_SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: wrong shape of labels array, its shape should be the same as logits shape with last dimension excluded, however got labels_shape = %s and logits_shape = %s instead !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
|
||||||
|
|
||||||
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));
|
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));
|
||||||
|
|
||||||
|
Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace());
|
||||||
|
|
||||||
Nd4jLong *dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace());
|
|
||||||
|
|
||||||
return SHAPELIST(CONSTANT(dLdpShapeInfo));
|
return SHAPELIST(CONSTANT(dLdpShapeInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -113,7 +113,7 @@ namespace ops {
|
||||||
originalIndices.linspace(0);
|
originalIndices.linspace(0);
|
||||||
ops::dynamic_partition op;
|
ops::dynamic_partition op;
|
||||||
auto res = op.execute({&originalIndices, indices}, {}, {numPartition});
|
auto res = op.execute({&originalIndices, indices}, {}, {numPartition});
|
||||||
REQUIRE_OK(res->status());
|
REQUIRE_TRUE(res->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
||||||
ops::dynamic_stitch stichOp;
|
ops::dynamic_stitch stichOp;
|
||||||
std::vector<NDArray*> partitions(numPartition * 2);
|
std::vector<NDArray*> partitions(numPartition * 2);
|
||||||
for (size_t i = 0; i < res->size(); i++) {
|
for (size_t i = 0; i < res->size(); i++) {
|
||||||
|
@ -122,7 +122,8 @@ namespace ops {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto result = stichOp.execute(partitions, {}, {numPartition}, {}, false);
|
auto result = stichOp.execute(partitions, {}, {numPartition}, {}, false);
|
||||||
REQUIRE_OK(result->status());
|
REQUIRE_TRUE(result->status() == ND4J_STATUS_OK, 0, "dynamic_partition_bp: Error with dynamic partitioning.");
|
||||||
|
result->at(0)->reshapei(outputList[0]->getShapeAsVector());
|
||||||
outputList[1]->assign(indices);
|
outputList[1]->assign(indices);
|
||||||
outputList[0]->assign(result->at(0));
|
outputList[0]->assign(result->at(0));
|
||||||
|
|
||||||
|
|
|
@ -46,8 +46,11 @@ namespace nd4j {
|
||||||
max = &m2;
|
max = &m2;
|
||||||
}
|
}
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
int numBits = INT_ARG(0);
|
int numBits = 8;
|
||||||
bool narrowed = INT_ARG(1);
|
if (block.getIArguments() && block.getIArguments()->size())
|
||||||
|
numBits = INT_ARG(0);
|
||||||
|
bool narrowed = false;
|
||||||
|
//INT_ARG(1);
|
||||||
if (block.getIArguments()->size() == 2) {
|
if (block.getIArguments()->size() == 2) {
|
||||||
numBits = INT_ARG(0);
|
numBits = INT_ARG(0);
|
||||||
narrowed = INT_ARG(1);
|
narrowed = INT_ARG(1);
|
||||||
|
|
|
@ -44,10 +44,10 @@ namespace nd4j {
|
||||||
if (targetRank == 0) { // scalar only
|
if (targetRank == 0) { // scalar only
|
||||||
determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
|
determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
|
||||||
}
|
}
|
||||||
else if (targetRank == 1) { // vector
|
else if (targetRank == 1) { // vector
|
||||||
determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape));
|
determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape));
|
||||||
}
|
}
|
||||||
else { // only two last dimensions are excluded
|
else { // only two last dimensions are excluded
|
||||||
determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape));
|
determinantShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), targetRank, shape::shapeOf(inShape));
|
||||||
}
|
}
|
||||||
return SHAPELIST(determinantShape);
|
return SHAPELIST(determinantShape);
|
||||||
|
@ -79,7 +79,7 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(input->rankOf() >=2, 0, "log_matrix_determinant: The rank of input array should not less than 2, but %i is given", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() >=2, 0, "log_matrix_determinant: The rank of input array should not less than 2, but %i is given", input->rankOf());
|
||||||
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "log_matrix_determinant: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
|
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "log_matrix_determinant: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
|
||||||
|
|
||||||
return helpers::log_abs_determinant(block.launchContext(), input, output);
|
return helpers::logAbsDeterminant(block.launchContext(), input, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(log_matrix_determinant) {
|
DECLARE_SHAPE_FN(log_matrix_determinant) {
|
||||||
|
@ -91,7 +91,7 @@ namespace nd4j {
|
||||||
if (targetRank == 0) { // scalar only
|
if (targetRank == 0) { // scalar only
|
||||||
determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
|
determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
|
||||||
}
|
}
|
||||||
else if (targetRank == 1) { // vector
|
else if (targetRank == 1) { // vector
|
||||||
determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape));
|
determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape));
|
||||||
}
|
}
|
||||||
else { // only two last dimensions are excluded
|
else { // only two last dimensions are excluded
|
||||||
|
@ -132,7 +132,7 @@ namespace nd4j {
|
||||||
if (targetRank == 0) { // scalar only
|
if (targetRank == 0) { // scalar only
|
||||||
determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
|
determinantShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
|
||||||
}
|
}
|
||||||
else if (targetRank == 1) { // vector
|
else if (targetRank == 1) { // vector
|
||||||
determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape));
|
determinantShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(shape::sizeAt(inShape, 0), ArrayOptions::dataType(inShape));
|
||||||
}
|
}
|
||||||
else { // only two last dimensions are excluded
|
else { // only two last dimensions are excluded
|
||||||
|
|
|
@ -31,36 +31,36 @@ namespace ops {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) {
|
CUSTOM_OP_IMPL(gruCell, 6, 4, false, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0); // input [bS x inSize]
|
auto x = INPUT_VARIABLE(0); // input [bS, nIn], nIn - input size
|
||||||
auto hLast = INPUT_VARIABLE(1); // previous cell output [bS x numUnits], that is at previous time step t-1
|
auto hLast = INPUT_VARIABLE(1); // previous cell output [bS, nU], that is at previous time step t-1, nU - number of units
|
||||||
auto Wru = INPUT_VARIABLE(2); // RU weights - [(nIn+nOut), 2*numUnits] - reset and update gates (input/recurrent weights)
|
auto Wru = INPUT_VARIABLE(2); // RU weights - [nIn+nU, 2*nU] - reset and update gates (input/recurrent weights)
|
||||||
auto Wc = INPUT_VARIABLE(3); // C weights - [(nIn+nOut), numUnits] - cell gate (input/recurrent weights)
|
auto Wc = INPUT_VARIABLE(3); // C weights - [nIn+nU, nU] - cell gate (input/recurrent weights)
|
||||||
auto bru = INPUT_VARIABLE(4); // reset and update biases, [2*numUnits] - reset and update gates
|
auto bru = INPUT_VARIABLE(4); // reset and update biases, [2*nU] - reset and update gates
|
||||||
auto bc = INPUT_VARIABLE(5); // cell biases, [numUnits]
|
auto bc = INPUT_VARIABLE(5); // cell biases, [nU]
|
||||||
|
|
||||||
auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, numUnits]
|
auto r = OUTPUT_VARIABLE(0); // Reset gate output [bS, nU]
|
||||||
auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, numUnits]
|
auto u = OUTPUT_VARIABLE(1); // Update gate output [bS, nU]
|
||||||
auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, numUnits]
|
auto c = OUTPUT_VARIABLE(2); // Cell gate output [bS, nU]
|
||||||
auto h = OUTPUT_VARIABLE(3); // current cell output [bS, numUnits]
|
auto h = OUTPUT_VARIABLE(3); // current cell output [bS, nU]
|
||||||
|
|
||||||
REQUIRE_TRUE(x->rankOf()==2 && hLast->rankOf()==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", x->rankOf(), hLast->rankOf());
|
REQUIRE_TRUE(x->rankOf()==2 && hLast->rankOf()==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", x->rankOf(), hLast->rankOf());
|
||||||
|
|
||||||
const int rank = x->rankOf();
|
const int rank = x->rankOf();
|
||||||
const auto bS = x->sizeAt(0);
|
const auto bS = x->sizeAt(0);
|
||||||
const auto nIn = x->sizeAt(1);
|
const auto nIn = x->sizeAt(1);
|
||||||
const auto nU = hLast->sizeAt(1);
|
const auto nU = hLast->sizeAt(1);
|
||||||
|
|
||||||
REQUIRE_TRUE(x->sizeAt(0) == hLast->sizeAt(0), 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast");
|
REQUIRE_TRUE(x->sizeAt(0) == hLast->sizeAt(0), 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast");
|
||||||
REQUIRE_TRUE(Wru->rankOf()==2 && Wc->rankOf()==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", Wru->rankOf(), Wc->rankOf());
|
REQUIRE_TRUE(Wru->rankOf()==2 && Wc->rankOf()==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", Wru->rankOf(), Wc->rankOf());
|
||||||
REQUIRE_TRUE(Wru->sizeAt(0)==(nIn+nU) && Wc->sizeAt(0)==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to inSize + numUnits, got %i", Wru->sizeAt(0));
|
REQUIRE_TRUE(Wru->sizeAt(0)==(nIn+nU) && Wc->sizeAt(0)==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to nIn + nU, got %i", Wru->sizeAt(0));
|
||||||
REQUIRE_TRUE(Wru->sizeAt(1)==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*numUnits, got %i", Wru->sizeAt(1));
|
REQUIRE_TRUE(Wru->sizeAt(1)==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*nU, got %i", Wru->sizeAt(1));
|
||||||
REQUIRE_TRUE(Wc->sizeAt(1)==nU, 0, "gruCell: Weights (cell) size(1) must be equal to numUnits, got %i", Wc->sizeAt(1));
|
REQUIRE_TRUE(Wc->sizeAt(1)==nU, 0, "gruCell: Weights (cell) size(1) must be equal to nU, got %i", Wc->sizeAt(1));
|
||||||
REQUIRE_TRUE(bru->rankOf()==1 && bru->sizeAt(0)==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*numUnits");
|
REQUIRE_TRUE(bru->rankOf()==1 && bru->sizeAt(0)==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*nU");
|
||||||
REQUIRE_TRUE(bc->rankOf()==1 && bc->sizeAt(0)==nU, 0, "gruCell: cell biases must be rank 1, size numUnits");
|
REQUIRE_TRUE(bc->rankOf()==1 && bc->sizeAt(0)==nU, 0, "gruCell: cell biases must be rank 1, size nU");
|
||||||
REQUIRE_TRUE(r->rankOf()==2 && u->rankOf()==2 && c->rankOf()==2 && h->rankOf()==2 &&
|
REQUIRE_TRUE(r->rankOf()==2 && u->rankOf()==2 && c->rankOf()==2 && h->rankOf()==2 &&
|
||||||
r->sizeAt(0)==bS && u->sizeAt(0)==bS && c->sizeAt(0)==bS && h->sizeAt(0)==bS &&
|
r->sizeAt(0)==bS && u->sizeAt(0)==bS && c->sizeAt(0)==bS && h->sizeAt(0)==bS &&
|
||||||
r->sizeAt(1)==nU && u->sizeAt(1)==nU && c->sizeAt(1)==nU && h->sizeAt(1)==nU,
|
r->sizeAt(1)==nU && u->sizeAt(1)==nU && c->sizeAt(1)==nU && h->sizeAt(1)==nU,
|
||||||
0, "gruCell: Output arrays must all be rank 2 with size(0) == batchSize and size(1) == numUnits");
|
0, "gruCell: Output arrays must all be rank 2 with size(0) == batchSize and size(1) == nU");
|
||||||
|
|
||||||
helpers::gruCell(block.launchContext(), x, hLast, Wru, Wc, bru, bc, r, u, c, h);
|
helpers::gruCell(block.launchContext(), x, hLast, Wru, Wc, bru, bc, r, u, c, h);
|
||||||
|
|
||||||
|
@ -80,39 +80,39 @@ DECLARE_TYPES(gruCell) {
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(gruCell) {
|
DECLARE_SHAPE_FN(gruCell) {
|
||||||
|
|
||||||
auto x = inputShape->at(0); // input [bS x inSize]
|
auto x = inputShape->at(0); // input [bS x nIn]
|
||||||
auto hLast = inputShape->at(1); // previous cell output [bS x numUnits], that is at previous time step t-1
|
auto hLast = inputShape->at(1); // previous cell output [bS x nU], that is at previous time step t-1
|
||||||
auto Wru = inputShape->at(2); // RU weights - [(nIn+nOut), 2*numUnits] - reset and update gates (input/recurrent weights)
|
auto Wru = inputShape->at(2); // RU weights - [(nIn+nU), 2*nU] - reset and update gates (input/recurrent weights)
|
||||||
auto Wc = inputShape->at(3); // C weights - [(nIn+nOut), numUnits] - cell gate (input/recurrent weights)
|
auto Wc = inputShape->at(3); // C weights - [(nIn+nU), nU] - cell gate (input/recurrent weights)
|
||||||
auto bru = inputShape->at(4); // reset and update biases, [2*numUnits] - reset and update gates
|
auto bru = inputShape->at(4); // reset and update biases, [2*nU] - reset and update gates
|
||||||
auto bc = inputShape->at(5); // cell biases, [numUnits]
|
auto bc = inputShape->at(5); // cell biases, [nU]
|
||||||
|
|
||||||
REQUIRE_TRUE(shape::rank(x)==2 && shape::rank(hLast)==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", shape::rank(x), shape::rank(hLast));
|
REQUIRE_TRUE(shape::rank(x)==2 && shape::rank(hLast)==2, 0, "gruCell: Input ranks must be 2 for inputs 0 and 1 (x, hLast) - got %i, %i", shape::rank(x), shape::rank(hLast));
|
||||||
|
|
||||||
const int rank = x[0];
|
const int rank = x[0];
|
||||||
const auto bS = x[1];
|
const auto bS = x[1];
|
||||||
const auto inSize = x[2];
|
const auto nIn = x[2];
|
||||||
const auto numUnits = hLast[2];
|
const auto nU = hLast[2];
|
||||||
|
|
||||||
REQUIRE_TRUE(x[1] == hLast[1], 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast");
|
REQUIRE_TRUE(x[1] == hLast[1], 0, "gruCell: Input minibatch sizes (dimension 0) must be same for x and hLast");
|
||||||
REQUIRE_TRUE(shape::rank(Wru)==2 && shape::rank(Wc)==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", shape::rank(Wru), shape::rank(Wc));
|
REQUIRE_TRUE(shape::rank(Wru)==2 && shape::rank(Wc)==2, 0, "gruCell: weight arrays (Wru, Wc) arrays must be 2, got %i and %i", shape::rank(Wru), shape::rank(Wc));
|
||||||
REQUIRE_TRUE(Wru[1]==(inSize+numUnits) && Wc[1]==(inSize+numUnits), 0, "gruCell: Weights size(0) must be equal to inSize + numUnits, got %i and %i", Wru[1], Wc[1]);
|
REQUIRE_TRUE(Wru[1]==(nIn+nU) && Wc[1]==(nIn+nU), 0, "gruCell: Weights size(0) must be equal to nIn + nU, got %i and %i", Wru[1], Wc[1]);
|
||||||
REQUIRE_TRUE(Wru[2]==(2*numUnits), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*numUnits, got %i", Wru[2]);
|
REQUIRE_TRUE(Wru[2]==(2*nU), 0, "gruCell: Weights (reset and update) size(1) must be equal to 2*nU, got %i", Wru[2]);
|
||||||
REQUIRE_TRUE(Wc[2]==numUnits, 0, "gruCell: Weights (cell) size(1) must be equal to numUnits, got %i", Wc[2]);
|
REQUIRE_TRUE(Wc[2]==nU, 0, "gruCell: Weights (cell) size(1) must be equal to nU, got %i", Wc[2]);
|
||||||
REQUIRE_TRUE(shape::rank(bru)==1 && bru[1]==(2*numUnits), 0, "gruCell: reset/update biases must be rank 1, size 2*numUnits");
|
REQUIRE_TRUE(shape::rank(bru)==1 && bru[1]==(2*nU), 0, "gruCell: reset/update biases must be rank 1, size 2*nU");
|
||||||
REQUIRE_TRUE(shape::rank(bc)==1 && bc[1]==numUnits, 0, "gruCell: cell biases must be rank 1, size numUnits");
|
REQUIRE_TRUE(shape::rank(bc)==1 && bc[1]==nU, 0, "gruCell: cell biases must be rank 1, size nU");
|
||||||
|
|
||||||
Nd4jLong *s0(nullptr);
|
Nd4jLong *s0(nullptr);
|
||||||
ALLOCATE(s0, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);// [bS x numUnits]
|
ALLOCATE(s0, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);// [bS x nU]
|
||||||
|
|
||||||
s0[0] = rank;
|
s0[0] = rank;
|
||||||
s0[1] = bS;
|
s0[1] = bS;
|
||||||
s0[2] = numUnits;
|
s0[2] = nU;
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(s0, x, shape::order(hLast));
|
ShapeUtils::updateStridesAndType(s0, x, shape::order(hLast));
|
||||||
auto ts0 = ConstantShapeHelper::getInstance()->createFromExisting(s0, block.workspace());
|
auto ts0 = ConstantShapeHelper::getInstance()->createFromExisting(s0, block.workspace());
|
||||||
|
|
||||||
//4 output shapes, all [bs, numUnits]
|
//4 output shapes, all [bs, nU]
|
||||||
return SHAPELIST(ts0, ts0, ts0, ts0);
|
return SHAPELIST(ts0, ts0, ts0, ts0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -152,9 +152,9 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
|
||||||
auto inGradH = INPUT_VARIABLE(6); // [bS x K x N]
|
auto inGradH = INPUT_VARIABLE(6); // [bS x K x N]
|
||||||
NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
|
NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
|
||||||
|
|
||||||
bool applyMask = false;
|
bool applyMask = false;
|
||||||
if (block.width() > 7) {
|
if (block.width() > 7) {
|
||||||
mask = INPUT_VARIABLE(7);
|
mask = INPUT_VARIABLE(7);
|
||||||
applyMask = true;
|
applyMask = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -166,7 +166,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
|
||||||
const int bS = x->shapeOf()[0];
|
const int bS = x->shapeOf()[0];
|
||||||
const int K = x->shapeOf()[1];
|
const int K = x->shapeOf()[1];
|
||||||
const int N = x->shapeOf()[2]; // N - number of time steps
|
const int N = x->shapeOf()[2]; // N - number of time steps
|
||||||
|
|
||||||
auto gradBias = NDArrayFactory::create_(x->ordering(), {bS, 2*K, N}, gradX->dataType(), block.launchContext());
|
auto gradBias = NDArrayFactory::create_(x->ordering(), {bS, 2*K, N}, gradX->dataType(), block.launchContext());
|
||||||
auto gradU = NDArrayFactory::create_(x->ordering(), {bS, 3*K, N}, gradX->dataType(), block.launchContext());
|
auto gradU = NDArrayFactory::create_(x->ordering(), {bS, 3*K, N}, gradX->dataType(), block.launchContext());
|
||||||
auto gradHX = NDArrayFactory::create_(x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext());
|
auto gradHX = NDArrayFactory::create_(x->ordering(), {bS, K, N}, gradX->dataType(), block.launchContext());
|
||||||
|
@ -199,7 +199,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
|
||||||
|
|
||||||
std::vector<Nd4jLong> idx = {0,0, 0,0, 0,0};
|
std::vector<Nd4jLong> idx = {0,0, 0,0, 0,0};
|
||||||
|
|
||||||
for (int t = N-1; t >=0 ; --t) {
|
for (int t = N-1; t >=0 ; --t) {
|
||||||
// initialization
|
// initialization
|
||||||
idx[4] = t;
|
idx[4] = t;
|
||||||
idx[5] = t + 1;
|
idx[5] = t + 1;
|
||||||
|
@ -242,7 +242,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
|
||||||
gct->applyPairwiseTransform(pairwise::Subtract, &xt, temp1, nullptr); // temp1 = (g_ct - xt)
|
gct->applyPairwiseTransform(pairwise::Subtract, &xt, temp1, nullptr); // temp1 = (g_ct - xt)
|
||||||
rtMinus->applyPairwiseTransform(pairwise::Multiply, &rt, temp2, nullptr); // temp2 = (1.0f - rt) * rt;
|
rtMinus->applyPairwiseTransform(pairwise::Multiply, &rt, temp2, nullptr); // temp2 = (1.0f - rt) * rt;
|
||||||
temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, nullptr); // temp1 = (g_ct - xt) * (1.0f - rt) * rt;
|
temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, nullptr); // temp1 = (g_ct - xt) * (1.0f - rt) * rt;
|
||||||
inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, &gradBRt, nullptr); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
|
inGradHt.applyPairwiseTransform(pairwise::Multiply, temp1, &gradBRt, nullptr); // = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
|
||||||
|
|
||||||
// bF, TODO - tanh
|
// bF, TODO - tanh
|
||||||
// gradTanh = (1.0f - g_ct * g_ct);
|
// gradTanh = (1.0f - g_ct * g_ct);
|
||||||
|
@ -259,7 +259,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
|
||||||
temp1->applyPairwiseTransform(pairwise::Multiply, temp2, &gradBFt, nullptr); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft;
|
temp1->applyPairwiseTransform(pairwise::Multiply, temp2, &gradBFt, nullptr); // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft;
|
||||||
|
|
||||||
// x_t (highway connection), gradHXt = inGradHt * (1.0f - rt);
|
// x_t (highway connection), gradHXt = inGradHt * (1.0f - rt);
|
||||||
inGradHt.applyPairwiseTransform(pairwise::Multiply, rtMinus, &gradHXt, nullptr);
|
inGradHt.applyPairwiseTransform(pairwise::Multiply, rtMinus, &gradHXt, nullptr);
|
||||||
|
|
||||||
// U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
|
// U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
|
||||||
rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, temp1, nullptr); // temp1 = rt * grad_tanh
|
rt.applyPairwiseTransform(pairwise::Multiply, gradTanh, temp1, nullptr); // temp1 = rt * grad_tanh
|
||||||
|
@ -280,14 +280,14 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
|
||||||
// gradInit
|
// gradInit
|
||||||
gradInit->assign(inGradCt);
|
gradInit->assign(inGradCt);
|
||||||
|
|
||||||
// gradX
|
// gradX
|
||||||
auto weightsT = w->transpose(); // [K x 3K]
|
auto weightsT = w->transpose(); // [K x 3K]
|
||||||
MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N]
|
MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.); // [bS x K x N]
|
||||||
gradX->applyPairwiseTransform(pairwise::Add, gradHX, gradX, nullptr); // + grad_highway_x
|
gradX->applyPairwiseTransform(pairwise::Add, gradHX, gradX, nullptr); // + grad_highway_x
|
||||||
if(applyMask)
|
if(applyMask)
|
||||||
gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
|
gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
|
||||||
|
|
||||||
// gradB
|
// gradB
|
||||||
auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K]
|
auto temp3 = gradBias->reduceAlongDimension(reduce::Sum, {0,2}, false, true); // [1 x 2K]
|
||||||
gradB->assign(temp3);
|
gradB->assign(temp3);
|
||||||
|
|
||||||
|
@ -298,7 +298,7 @@ CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {
|
||||||
delete gct; delete gradU; delete gradHX;
|
delete gct; delete gradU; delete gradHX;
|
||||||
delete temp1; delete temp2; delete temp3; delete gradCt; delete wi;
|
delete temp1; delete temp2; delete temp3; delete gradCt; delete wi;
|
||||||
delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias;
|
delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias;
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -322,21 +322,21 @@ DECLARE_SHAPE_FN(sru_bp) {
|
||||||
ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
|
ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
|
||||||
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
|
CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
|
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
|
||||||
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
|
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
|
||||||
auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize]
|
auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize]
|
||||||
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
|
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
|
||||||
NDArray* mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
|
NDArray* mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
|
||||||
|
|
||||||
auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize]
|
auto ht = OUTPUT_VARIABLE(0); // h_t, [time x bS x 2*inSize]
|
||||||
auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize]
|
auto ct = OUTPUT_VARIABLE(1); // c_t, [time x bS x 2*inSize]
|
||||||
|
|
||||||
// input shapes validation
|
// input shapes validation
|
||||||
const int rank = x->rankOf();
|
const int rank = x->rankOf();
|
||||||
|
@ -345,20 +345,20 @@ CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
|
||||||
|
|
||||||
REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BI operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf());
|
REQUIRE_TRUE(x->rankOf() == rank, 0, "SRU_BI operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf());
|
||||||
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
|
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
|
||||||
REQUIRE_TRUE(b->rankOf() <= rank-1, 0, "SRU_BI operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf());
|
REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf());
|
||||||
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
|
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
|
||||||
if(mask)
|
if(mask)
|
||||||
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
|
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
|
||||||
|
|
||||||
const std::string wShape = ShapeUtils::shapeAsString(w);
|
const std::string wShape = ShapeUtils::shapeAsString(w);
|
||||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
||||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
|
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
|
||||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
|
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
|
||||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
||||||
|
|
||||||
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||||
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
||||||
if(mask) {
|
if(mask) {
|
||||||
const std::string maskShape = ShapeUtils::shapeAsString(mask);
|
const std::string maskShape = ShapeUtils::shapeAsString(mask);
|
||||||
|
@ -366,7 +366,7 @@ CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
|
||||||
}
|
}
|
||||||
|
|
||||||
helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct);
|
helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,20 +379,20 @@ CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
|
||||||
DECLARE_SHAPE_FN(sru_bi) {
|
DECLARE_SHAPE_FN(sru_bi) {
|
||||||
|
|
||||||
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
|
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
|
||||||
auto wShapeInfo = inputShape->at(1);
|
auto wShapeInfo = inputShape->at(1);
|
||||||
auto bShapeInfo = inputShape->at(2);
|
auto bShapeInfo = inputShape->at(2);
|
||||||
auto c0ShapeInfo = inputShape->at(3);
|
auto c0ShapeInfo = inputShape->at(3);
|
||||||
Nd4jLong* maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
|
Nd4jLong* maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr; // optional, 2d tensor of dropout mask [bS x inSize]
|
||||||
|
|
||||||
const int rank = xShapeInfo[0]; // = 3
|
const int rank = xShapeInfo[0]; // = 3
|
||||||
const Nd4jLong time = xShapeInfo[1];
|
const Nd4jLong time = xShapeInfo[1];
|
||||||
const Nd4jLong bS = xShapeInfo[2];
|
const Nd4jLong bS = xShapeInfo[2];
|
||||||
const Nd4jLong inSize = xShapeInfo[3] / 2;
|
const Nd4jLong inSize = xShapeInfo[3] / 2;
|
||||||
|
|
||||||
|
|
||||||
// input shapes validation
|
// input shapes validation
|
||||||
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
|
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
|
||||||
REQUIRE_TRUE(bShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of biases array, expected is <=2, but got %i instead !", rank-1, bShapeInfo[0]);
|
REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]);
|
||||||
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
|
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
|
||||||
if(maskShapeInfo)
|
if(maskShapeInfo)
|
||||||
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
|
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
|
||||||
|
@ -400,12 +400,12 @@ DECLARE_SHAPE_FN(sru_bi) {
|
||||||
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
||||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
||||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
|
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
|
||||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
|
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
|
||||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
||||||
|
|
||||||
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||||
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
||||||
if(maskShapeInfo) {
|
if(maskShapeInfo) {
|
||||||
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
|
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
|
||||||
|
@ -428,15 +428,15 @@ DECLARE_SHAPE_FN(sru_bi) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
|
CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
|
auto x = INPUT_VARIABLE(0); // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
|
||||||
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
|
auto w = INPUT_VARIABLE(1); // W, 2d tensor of weights [2*inSize x 6*inSize]
|
||||||
auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 4*inSize]
|
auto b = INPUT_VARIABLE(2); // B, row of biases with twice length [4*inSize]
|
||||||
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
|
auto c0 = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
|
||||||
auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize]
|
auto ct = INPUT_VARIABLE(4); // C, [time x bS x 2*inSize]
|
||||||
auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize]
|
auto inGradC0 = INPUT_VARIABLE(5); // [bS x 2*inSize]
|
||||||
auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize]
|
auto inGradHt = INPUT_VARIABLE(6); // [time x bS x 2*inSize]
|
||||||
NDArray* mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
|
NDArray* mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; // optional, 2d tensor of dropout mask [bS x 2*inSize]
|
||||||
|
|
||||||
// input shapes validation
|
// input shapes validation
|
||||||
const int rank = x->rankOf();
|
const int rank = x->rankOf();
|
||||||
|
@ -445,7 +445,7 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
|
||||||
const Nd4jLong inSize = x->sizeAt(2) / 2;
|
const Nd4jLong inSize = x->sizeAt(2) / 2;
|
||||||
|
|
||||||
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
|
REQUIRE_TRUE(w->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
|
||||||
REQUIRE_TRUE(b->rankOf() <= rank-1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", b->rankOf());
|
REQUIRE_TRUE(b->rankOf() == 1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf());
|
||||||
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
|
REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
|
||||||
REQUIRE_TRUE(ct->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ct->rankOf());
|
REQUIRE_TRUE(ct->rankOf() == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ct->rankOf());
|
||||||
REQUIRE_TRUE(inGradC0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0->rankOf());
|
REQUIRE_TRUE(inGradC0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0->rankOf());
|
||||||
|
@ -456,7 +456,7 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
|
||||||
const std::string wShape = ShapeUtils::shapeAsString(w);
|
const std::string wShape = ShapeUtils::shapeAsString(w);
|
||||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
||||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
|
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
|
||||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
|
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
|
||||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
||||||
const std::string ctShape = ShapeUtils::shapeAsString(ct);
|
const std::string ctShape = ShapeUtils::shapeAsString(ct);
|
||||||
|
@ -470,7 +470,7 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
|
||||||
const std::string maskShape = ShapeUtils::shapeAsString(mask);
|
const std::string maskShape = ShapeUtils::shapeAsString(mask);
|
||||||
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
|
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize]
|
auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize]
|
||||||
auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize]
|
auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize]
|
||||||
auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize]
|
auto gradB = OUTPUT_VARIABLE(2); // [1 x 4*inSize]
|
||||||
|
@ -484,9 +484,9 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
|
||||||
DECLARE_SHAPE_FN(sru_bi_bp) {
|
DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
|
|
||||||
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
|
auto xShapeInfo = inputShape->at(0); // [time x bS x 2K ]
|
||||||
auto wShapeInfo = inputShape->at(1);
|
auto wShapeInfo = inputShape->at(1);
|
||||||
auto bShapeInfo = inputShape->at(2);
|
auto bShapeInfo = inputShape->at(2);
|
||||||
auto c0ShapeInfo = inputShape->at(3);
|
auto c0ShapeInfo = inputShape->at(3);
|
||||||
auto ctShapeInfo = inputShape->at(4);
|
auto ctShapeInfo = inputShape->at(4);
|
||||||
auto inGradC0ShapeInfo = inputShape->at(5);
|
auto inGradC0ShapeInfo = inputShape->at(5);
|
||||||
auto inGradHtShapeInfo = inputShape->at(6);
|
auto inGradHtShapeInfo = inputShape->at(6);
|
||||||
|
@ -499,7 +499,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
const Nd4jLong inSize = xShapeInfo[3] / 2;
|
const Nd4jLong inSize = xShapeInfo[3] / 2;
|
||||||
|
|
||||||
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
|
REQUIRE_TRUE(wShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
|
||||||
REQUIRE_TRUE(bShapeInfo[0] <= rank-1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is <=2, but got %i instead !", bShapeInfo);
|
REQUIRE_TRUE(bShapeInfo[0] == 1, 0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]);
|
||||||
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo);
|
REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo);
|
||||||
REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ctShapeInfo);
|
REQUIRE_TRUE(ctShapeInfo[0] == rank, 0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ctShapeInfo);
|
||||||
REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0ShapeInfo[0]);
|
REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0ShapeInfo[0]);
|
||||||
|
@ -510,7 +510,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
||||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
||||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({1, 4*inSize});
|
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
|
||||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
|
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
|
||||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
||||||
const std::string ctShape = ShapeUtils::shapeAsString(ctShapeInfo);
|
const std::string ctShape = ShapeUtils::shapeAsString(ctShapeInfo);
|
||||||
|
@ -535,11 +535,11 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
|
|
||||||
ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize});
|
ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize});
|
||||||
ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, {time, 2 * inSize, 6 * inSize});
|
ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, {time, 2 * inSize, 6 * inSize});
|
||||||
ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, {1, 4 * inSize});
|
ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, {4 * inSize});
|
||||||
ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, {bS, 2 * inSize});
|
ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, {bS, 2 * inSize});
|
||||||
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -549,15 +549,15 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
/**
|
/**
|
||||||
* Implementation of operations for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
|
* Implementation of operations for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
|
||||||
*
|
*
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
|
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
|
||||||
* 1: 2d tensor of weights [3K x K]
|
* 1: 2d tensor of weights [3K x K]
|
||||||
* 2: row of biases with twice length [1 x 2K]
|
* 2: row of biases with twice length [1 x 2K]
|
||||||
* 3: 2d tensor of previous cell state [bS x K]
|
* 3: 2d tensor of previous cell state [bS x K]
|
||||||
* 4: optional, 2d tensor of dropout mask [bS x K]
|
* 4: optional, 2d tensor of dropout mask [bS x K]
|
||||||
*
|
*
|
||||||
* Output arrays:
|
* Output arrays:
|
||||||
* 0: 3d tensor of cell output [bS x K x N]
|
* 0: 3d tensor of cell output [bS x K x N]
|
||||||
* 1: 3d tensor of cell state [bS x K x N]
|
* 1: 3d tensor of cell state [bS x K x N]
|
||||||
*/
|
*/
|
||||||
|
@ -568,15 +568,15 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
/**
|
/**
|
||||||
* Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
|
* Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
|
||||||
*
|
*
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
|
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
|
||||||
* 1: 2d tensor of weights [3K x K]
|
* 1: 2d tensor of weights [3K x K]
|
||||||
* 2: row of biases with twice length [1 x 2K]
|
* 2: row of biases with twice length [1 x 2K]
|
||||||
* 3: 2d tensor of previous cell state [bS x K]
|
* 3: 2d tensor of previous cell state [bS x K]
|
||||||
* 4: optional, 2d tensor of dropout mask [bS x K]
|
* 4: optional, 2d tensor of dropout mask [bS x K]
|
||||||
*
|
*
|
||||||
* Output arrays:
|
* Output arrays:
|
||||||
* 0: 3d tensor of cell output [bS x K x N]
|
* 0: 3d tensor of cell output [bS x K x N]
|
||||||
* 1: 3d tensor of cell state [bS x K x N]
|
* 1: 3d tensor of cell state [bS x K x N]
|
||||||
*/
|
*/
|
||||||
|
@ -588,8 +588,8 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
/**
|
/**
|
||||||
* Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
|
* Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
|
||||||
*
|
*
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
|
* 0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
|
||||||
* 1: 2d tensor of weights [3K x K]
|
* 1: 2d tensor of weights [3K x K]
|
||||||
* 2: row of biases with twice length [1 x 2K]
|
* 2: row of biases with twice length [1 x 2K]
|
||||||
|
@ -598,13 +598,13 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
* 5: 2d tensor of cell state gradients [bS x K]
|
* 5: 2d tensor of cell state gradients [bS x K]
|
||||||
* 6: 3d tensor of state output gradients [bS x K x N]
|
* 6: 3d tensor of state output gradients [bS x K x N]
|
||||||
* 7: optional, 2d tensor of dropout mask [bS x K]
|
* 7: optional, 2d tensor of dropout mask [bS x K]
|
||||||
*
|
*
|
||||||
* Output arrays:
|
* Output arrays:
|
||||||
* 0: 3d tensor of input gradients [bS x K x N]
|
* 0: 3d tensor of input gradients [bS x K x N]
|
||||||
* 1: 3d tensor of weights gradients [bS x 3K x K]
|
* 1: 3d tensor of weights gradients [bS x 3K x K]
|
||||||
* 2: 2d, row of biases gradients [1 x 2K]
|
* 2: 2d, row of biases gradients [1 x 2K]
|
||||||
* 3: 2d, tensor of state gradients [bS x K]
|
* 3: 2d, tensor of state gradients [bS x K]
|
||||||
*/
|
*/
|
||||||
// #if NOT_EXCLUDED(OP_sru_logic)
|
// #if NOT_EXCLUDED(OP_sru_logic)
|
||||||
// DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0);
|
// DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true, 0, 0);
|
||||||
// #endif
|
// #endif
|
||||||
|
@ -618,27 +618,27 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////
|
||||||
// CUSTOM_OP_IMPL(sru_logic, 5, 2, false, 0, 0) {
|
// CUSTOM_OP_IMPL(sru_logic, 5, 2, false, 0, 0) {
|
||||||
|
|
||||||
// auto input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
|
// auto input = INPUT_VARIABLE(0); // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
|
||||||
// auto weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
|
// auto weights = INPUT_VARIABLE(1); // W, 2d tensor of weights [3K x K]
|
||||||
// auto bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K]
|
// auto bias = INPUT_VARIABLE(2); // B, row of biases with twice length [1 x 2*K]
|
||||||
// auto init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
|
// auto init = INPUT_VARIABLE(3); // C_{0}, 2d tensor of initial state [bS x K] at time t=0
|
||||||
// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
|
// NDArray* mask = nullptr; // optional, 2d tensor of dropout mask [bS x K]
|
||||||
|
|
||||||
// bool applyMask = false;
|
// bool applyMask = false;
|
||||||
// if (block.width() > 4) {
|
// if (block.width() > 4) {
|
||||||
// mask = INPUT_VARIABLE(4);
|
// mask = INPUT_VARIABLE(4);
|
||||||
// applyMask = true;
|
// applyMask = true;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// auto output = OUTPUT_VARIABLE(0); // h_t, [bS x K x N]
|
// auto output = OUTPUT_VARIABLE(0); // h_t, [bS x K x N]
|
||||||
// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x K x N]
|
// auto state = OUTPUT_VARIABLE(1); // c_t, [bS x K x N]
|
||||||
|
|
||||||
// const int bS = input->shapeOf()[0]; // bS - batch size
|
// const int bS = input->shapeOf()[0]; // bS - batch size
|
||||||
// const int K = input->shapeOf()[1]; // K - number of features
|
// const int K = input->shapeOf()[1]; // K - number of features
|
||||||
// const int N = input->shapeOf()[2]; // N - number of time steps
|
// const int N = input->shapeOf()[2]; // N - number of time steps
|
||||||
|
|
||||||
// const auto wi = mmul(*weights, *input); // U [bS x 3K x N]
|
// const auto wi = mmul(*weights, *input); // U [bS x 3K x N]
|
||||||
// const auto bF = (*bias)({0,0, 0, K}); // biases for forget gate [1 x K]
|
// const auto bF = (*bias)({0,0, 0, K}); // biases for forget gate [1 x K]
|
||||||
// const auto bR = (*bias)({0,0, K,2*K}); // biases for reset gate [1 x K]
|
// const auto bR = (*bias)({0,0, K,2*K}); // biases for reset gate [1 x K]
|
||||||
|
@ -664,7 +664,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
// ft = sigmoid_(ft + bF);
|
// ft = sigmoid_(ft + bF);
|
||||||
// rt = sigmoid_(rt + bR);
|
// rt = sigmoid_(rt + bR);
|
||||||
|
|
||||||
// ct = ft * (ct - zt) + zt;
|
// ct = ft * (ct - zt) + zt;
|
||||||
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
|
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
|
||||||
// ct.applyTransform(transform::Tanh, &gct);
|
// ct.applyTransform(transform::Tanh, &gct);
|
||||||
// ht = rt * (gct - xt) + xt;
|
// ht = rt * (gct - xt) + xt;
|
||||||
|
@ -694,16 +694,16 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
|
|
||||||
// Nd4jLong* newShapeInfo1 = nullptr;
|
// Nd4jLong* newShapeInfo1 = nullptr;
|
||||||
// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong);
|
// ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong);
|
||||||
|
|
||||||
// newShapeInfo1[0] = rank;
|
// newShapeInfo1[0] = rank;
|
||||||
// newShapeInfo1[1] = bS;
|
// newShapeInfo1[1] = bS;
|
||||||
// newShapeInfo1[2] = K;
|
// newShapeInfo1[2] = K;
|
||||||
// newShapeInfo1[3] = N;
|
// newShapeInfo1[3] = N;
|
||||||
|
|
||||||
// ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order);
|
// ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order);
|
||||||
// auto result = CONSTANT(newShapeInfo1);
|
// auto result = CONSTANT(newShapeInfo1);
|
||||||
// return SHAPELIST(result, result);
|
// return SHAPELIST(result, result);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
|
||||||
// //////////////////////////////////////////////////////////////////////////
|
// //////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -860,7 +860,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
// const std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, inSize});
|
// const std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, inSize});
|
||||||
// const std::string inGradHShape = ShapeUtils::shapeAsString(inGradH);
|
// const std::string inGradHShape = ShapeUtils::shapeAsString(inGradH);
|
||||||
// const std::string inGradHCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time});
|
// const std::string inGradHCorrectShape = ShapeUtils::shapeAsString({bS, inSize, time});
|
||||||
|
|
||||||
// REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
// REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||||
// // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
// // REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||||
// REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
// REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
||||||
|
@ -896,11 +896,11 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
// auto inGradHt = (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize]
|
// auto inGradHt = (*inGradH)({ 0,0, 0,0, t,t+1}); // [bS x inSize x time] -> [bS x inSize]
|
||||||
|
|
||||||
// auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1}
|
// auto ct_1 = t ? (*c)({ 0,0, 0,0, t-1,t}) : *c0; // previous c_{t-1}
|
||||||
|
|
||||||
// ///////////////// forward
|
// ///////////////// forward
|
||||||
// // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
|
// // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
|
||||||
// ft = sigmoid_(ft + bF);
|
// ft = sigmoid_(ft + bF);
|
||||||
// rt = sigmoid_(rt + bR);
|
// rt = sigmoid_(rt + bR);
|
||||||
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
|
// // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
|
||||||
// ct.applyTransform(transform::Tanh, &gct);
|
// ct.applyTransform(transform::Tanh, &gct);
|
||||||
|
|
||||||
|
@ -910,7 +910,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
// NDArray ftMinus = 1. - ft;
|
// NDArray ftMinus = 1. - ft;
|
||||||
// NDArray rtMinus = 1. - rt;
|
// NDArray rtMinus = 1. - rt;
|
||||||
// NDArray gradBRt = inGradHt * (gct - xt) * rtMinus * rt;
|
// NDArray gradBRt = inGradHt * (gct - xt) * rtMinus * rt;
|
||||||
// // bF, TODO - tanh
|
// // bF, TODO - tanh
|
||||||
// NDArray gradTanh = 1. - gct * gct;
|
// NDArray gradTanh = 1. - gct * gct;
|
||||||
// NDArray gradCt = inGradHt * rt * gradTanh;
|
// NDArray gradCt = inGradHt * rt * gradTanh;
|
||||||
// NDArray gradBFt = (gradCt + *inGradCt) * (ct_1 - zt) * ftMinus * ft;
|
// NDArray gradBFt = (gradCt + *inGradCt) * (ct_1 - zt) * ftMinus * ft;
|
||||||
|
@ -923,7 +923,7 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
// // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft;
|
// // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft;
|
||||||
// *inGradCt = (gradCt + *inGradCt) * ft;
|
// *inGradCt = (gradCt + *inGradCt) * ft;
|
||||||
|
|
||||||
// // save results
|
// // save results
|
||||||
// gradBias({0,0, 0,inSize, t,t+1}, true).assign(gradBFt);
|
// gradBias({0,0, 0,inSize, t,t+1}, true).assign(gradBFt);
|
||||||
// gradBias({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBRt);
|
// gradBias({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBRt);
|
||||||
// gradU({0,0, 0,inSize, t,t+1}, true).assign(gradUZt);
|
// gradU({0,0, 0,inSize, t,t+1}, true).assign(gradUZt);
|
||||||
|
@ -934,19 +934,19 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
|
|
||||||
// // gradInit
|
// // gradInit
|
||||||
// gradInit->assign(inGradCt);
|
// gradInit->assign(inGradCt);
|
||||||
// // gradX
|
// // gradX
|
||||||
// w->transposei(); // [inSize x 3K]
|
// w->transposei(); // [inSize x 3K]
|
||||||
// gradX->assign( mmul(*w, gradU) + gradHX);
|
// gradX->assign( mmul(*w, gradU) + gradHX);
|
||||||
// if(mask)
|
// if(mask)
|
||||||
// gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
|
// gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr); // apply mask
|
||||||
|
|
||||||
// // gradB
|
// // gradB
|
||||||
// gradBias.reduceAlongDimension(reduce::Sum, gradB, {0,2}, false, true); // [1 x 2K]
|
// gradBias.reduceAlongDimension(reduce::Sum, gradB, {0,2}, false, true); // [1 x 2K]
|
||||||
|
|
||||||
// // gradW [bS x 3K x inSize]
|
// // gradW [bS x 3K x inSize]
|
||||||
// x->permutei({0, 2, 1}); // [bS x time x inSize]
|
// x->permutei({0, 2, 1}); // [bS x time x inSize]
|
||||||
// gradW->assign( mmul(gradU, *x) );
|
// gradW->assign( mmul(gradU, *x) );
|
||||||
|
|
||||||
// return Status::OK();
|
// return Status::OK();
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
@ -969,4 +969,4 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
||||||
// ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
|
// ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
|
||||||
|
|
||||||
// return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
|
// return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor1), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor2), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor3), ConstantShapeHelper::getInstance()->createShapeInfo(descriptor4));
|
||||||
// }
|
// }
|
||||||
|
|
|
@ -293,46 +293,46 @@ void softmax(nd4j::LaunchContext * context, const NDArray& input, NDArray& outpu
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) {
|
void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) {
|
||||||
const Nd4jLong inputLen = input.lengthOf();
|
const Nd4jLong inputLen = input.lengthOf();
|
||||||
const Nd4jLong* inputShapeInfo = input.getShapeInfo();
|
const Nd4jLong* inputShapeInfo = input.getShapeInfo();
|
||||||
const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo();
|
const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo();
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(inputLen > Environment::getInstance()->elementwiseThreshold())
|
PRAGMA_OMP_PARALLEL_FOR_IF(inputLen > Environment::getInstance()->elementwiseThreshold())
|
||||||
for(Nd4jLong i = 0; i < inputLen; ++i) {
|
for(Nd4jLong i = 0; i < inputLen; ++i) {
|
||||||
// FIXME: double!
|
// FIXME: double!
|
||||||
double x = input.e<double>(i);
|
double x = input.e<double>(i);
|
||||||
if(x < 0.0) {
|
if(x < 0.0) {
|
||||||
// FIXME: double
|
|
||||||
output.p(i, (x * alpha.e<double>(shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo))));
|
|
||||||
} else
|
|
||||||
output.p(i, x);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
|
|
||||||
|
|
||||||
const Nd4jLong inputLen = input.lengthOf();
|
|
||||||
const Nd4jLong* inputShapeInfo = input.getShapeInfo();
|
|
||||||
const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo();
|
|
||||||
|
|
||||||
dLdA.assign(0.0f);
|
|
||||||
|
|
||||||
for(Nd4jLong i = 0; i < inputLen; ++i) {
|
|
||||||
// FIXME: double
|
// FIXME: double
|
||||||
double x = input.e<double>(i);
|
output.p(i, (x * alpha.e<double>(shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo))));
|
||||||
double grO = dLdO.e<double>(i);
|
} else
|
||||||
if(x < 0.0) {
|
output.p(i, x);
|
||||||
Nd4jLong alphaInd = shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo);
|
}
|
||||||
dLdI.p(i, grO * alpha.e<double>(alphaInd));
|
}
|
||||||
double prevVal = dLdA.e<double>(alphaInd);
|
|
||||||
prevVal += (grO * x);
|
//////////////////////////////////////////////////////////////////////////
|
||||||
dLdA.p(alphaInd, prevVal );
|
void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
|
||||||
}
|
|
||||||
else
|
const Nd4jLong inputLen = input.lengthOf();
|
||||||
dLdI.p(i, grO);
|
const Nd4jLong* inputShapeInfo = input.getShapeInfo();
|
||||||
|
const Nd4jLong* alphaShapeInfo = alpha.getShapeInfo();
|
||||||
|
|
||||||
|
dLdA.assign(0.0f);
|
||||||
|
|
||||||
|
for(Nd4jLong i = 0; i < inputLen; ++i) {
|
||||||
|
// FIXME: double
|
||||||
|
double x = input.e<double>(i);
|
||||||
|
double grO = dLdO.e<double>(i);
|
||||||
|
if(x < 0.0) {
|
||||||
|
Nd4jLong alphaInd = shape::subArrayIndex(i, inputShapeInfo, alphaShapeInfo);
|
||||||
|
dLdI.p(i, grO * alpha.e<double>(alphaInd));
|
||||||
|
double prevVal = dLdA.e<double>(alphaInd);
|
||||||
|
prevVal += (grO * x);
|
||||||
|
dLdA.p(alphaInd, prevVal);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
dLdI.p(i, grO);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
|
||||||
}
|
}
|
||||||
|
|
||||||
// auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt); // sigmaInvGam = 1 / sqrt(variance + epsilon)
|
// auto sigmaInvGam = (*variance + epsilon).transform(transform::RSqrt); // sigmaInvGam = 1 / sqrt(variance + epsilon)
|
||||||
// if(gamma != nullptr) sigmaInvGam *= *gamma;
|
// if(gamma != nullptr) sigmaInvGam *= *gamma;
|
||||||
|
|
||||||
const T* sigmaBuff = sigmaInvGam.bufferAsT<T>();
|
const T* sigmaBuff = sigmaInvGam.bufferAsT<T>();
|
||||||
const T* meanBuff = mean->bufferAsT<T>();
|
const T* meanBuff = mean->bufferAsT<T>();
|
||||||
|
@ -60,8 +60,8 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
|
||||||
uint inShapeInfoCast[MAX_RANK];
|
uint inShapeInfoCast[MAX_RANK];
|
||||||
uint meanShapeInfoCast[MAX_RANK];
|
uint meanShapeInfoCast[MAX_RANK];
|
||||||
bool canCastIn = nd4j::DataTypeUtils::castShapeInfo(inShapeInfo, inShapeInfoCast);
|
bool canCastIn = nd4j::DataTypeUtils::castShapeInfo(inShapeInfo, inShapeInfoCast);
|
||||||
bool canCastMean = nd4j::DataTypeUtils::castShapeInfo(meanShapeInfo, meanShapeInfoCast);
|
bool canCastMean = nd4j::DataTypeUtils::castShapeInfo(meanShapeInfo, meanShapeInfoCast);
|
||||||
|
|
||||||
const Nd4jLong step = lenBig / lenSmall;
|
const Nd4jLong step = lenBig / lenSmall;
|
||||||
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes);
|
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(input->rankOf(), axes);
|
||||||
|
|
||||||
|
@ -70,58 +70,62 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
|
||||||
if(beta != nullptr) {
|
if(beta != nullptr) {
|
||||||
const T* betaBuff = beta->bufferAsT<T>();
|
const T* betaBuff = beta->bufferAsT<T>();
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
||||||
{
|
{
|
||||||
const auto threadNum = omp_get_thread_num();
|
const auto threadNum = omp_get_thread_num();
|
||||||
Nd4jLong* inOffsets = new Nd4jLong[step];
|
Nd4jLong* inOffsets = new Nd4jLong[step];
|
||||||
|
Nd4jLong* memBuff = new Nd4jLong[2 * inShapeInfo[0]];
|
||||||
for (int j = 0; j < lenSmall; ++j) {
|
|
||||||
|
for (int j = 0; j < lenSmall; ++j) {
|
||||||
|
|
||||||
const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads;
|
const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads;
|
||||||
if (!isOwner) continue;
|
if (!isOwner) continue;
|
||||||
|
|
||||||
const Nd4jLong start = j * step;
|
const Nd4jLong start = j * step;
|
||||||
const Nd4jLong end = start + step;
|
const Nd4jLong end = start + step;
|
||||||
|
|
||||||
// calculate offset for mean, variance, gamma, beta (all of them have the same shape)
|
// calculate offset for mean, variance, gamma, beta (all of them have the same shape)
|
||||||
auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean);
|
auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean);
|
||||||
// calculate offset for input and output (all of them have the same shape)
|
// calculate offset for input and output (all of them have the same shape)
|
||||||
shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, dimsToExclude.data());
|
shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data());
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < step; ++i) {
|
for (Nd4jLong i = 0; i < step; ++i) {
|
||||||
auto offsetBig = inOffsets[i];
|
auto offsetBig = inOffsets[i];
|
||||||
outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall] + betaBuff[offsetSmall];
|
outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall] + betaBuff[offsetSmall];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete []inOffsets;
|
delete []inOffsets;
|
||||||
}
|
delete []memBuff;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
|
||||||
{
|
{
|
||||||
const auto threadNum = omp_get_thread_num();
|
const auto threadNum = omp_get_thread_num();
|
||||||
Nd4jLong* inOffsets = new Nd4jLong[step];
|
Nd4jLong* inOffsets = new Nd4jLong[step];
|
||||||
|
Nd4jLong* memBuff = new Nd4jLong[2 * inShapeInfo[0]];
|
||||||
for (int j = 0; j < lenSmall; ++j) {
|
|
||||||
|
for (int j = 0; j < lenSmall; ++j) {
|
||||||
|
|
||||||
const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads;
|
const bool isOwner = j < info._numThreads ? threadNum == j : threadNum == j % info._numThreads;
|
||||||
if (!isOwner) continue;
|
if (!isOwner) continue;
|
||||||
|
|
||||||
const Nd4jLong start = j * step;
|
const Nd4jLong start = j * step;
|
||||||
const Nd4jLong end = start + step;
|
const Nd4jLong end = start + step;
|
||||||
|
|
||||||
// calculate offset for mean, variance, gamma, beta (all of them have the same shape)
|
// calculate offset for mean, variance, gamma, beta (all of them have the same shape)
|
||||||
auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean);
|
auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean);
|
||||||
// calculate offset for input and output (all of them have the same shape)
|
// calculate offset for input and output (all of them have the same shape)
|
||||||
shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, dimsToExclude.data());
|
shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data());
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (Nd4jLong i = 0; i < step; ++i) {
|
for (Nd4jLong i = 0; i < step; ++i) {
|
||||||
auto offsetBig = inOffsets[i];
|
auto offsetBig = inOffsets[i];
|
||||||
outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall];
|
outBuff[offsetBig] = (inBuff[offsetBig] - meanBuff[offsetSmall]) * sigmaBuff[offsetSmall];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete []inOffsets;
|
delete []inOffsets;
|
||||||
|
delete []memBuff;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,11 +30,20 @@ namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case
|
if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) {// except case
|
||||||
auto lesser = (x_shape->lengthOf() == 1 ? x_shape: y_shape);
|
// lenght are equals
|
||||||
auto greater = (x_shape->lengthOf() == 1 ? y_shape: x_shape);
|
if (x_shape->lengthOf() == y_shape->lengthOf()) {
|
||||||
output->assign(greater);
|
auto greater = (x_shape->e<Nd4jLong>(0) < y_shape->e<Nd4jLong>(0) ? y_shape : x_shape);
|
||||||
|
output->assign(greater);
|
||||||
output->p(greater->lengthOf() - 1, lesser->e(0L));
|
}
|
||||||
|
else {
|
||||||
|
auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape);
|
||||||
|
auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape);
|
||||||
|
output->assign(greater);
|
||||||
|
auto lastG = greater->lengthOf() - 1;
|
||||||
|
auto lastL = lesser->lengthOf() - 1;
|
||||||
|
if (greater->e<Nd4jLong>(lastG) < lesser->e<Nd4jLong>(lastL))
|
||||||
|
output->p(lastG, lesser->e(lastL));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
//int e = 0, x = 0, y = 0;
|
//int e = 0, x = 0, y = 0;
|
||||||
|
|
|
@ -1418,18 +1418,15 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
||||||
for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
|
for (Nd4jLong kw = wstart; kw < wend; kw += iStep3)
|
||||||
sum += pIn[kh + kw];
|
sum += pIn[kh + kw];
|
||||||
|
|
||||||
|
if (extraParam0 == 0) { //Exclude padding
|
||||||
auto oi = b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3;
|
int a = (hend-hstart)/iStep2 + ((hend-hstart) % iStep2 == 0 ? 0 : 1);
|
||||||
|
int b = (wend-wstart)/iStep3 + ((wend-wstart) % iStep3 == 0 ? 0 : 1);
|
||||||
if (extraParam0 == 0) { //Exclude padding
|
sum /= static_cast<T>(a * b); // Accounts for dilation
|
||||||
int _a = (hend-hstart)/iStep2 + ((hend-hstart) % iStep2 == 0 ? 0 : 1);
|
}
|
||||||
int _b = (wend-wstart)/iStep3 + ((wend-wstart) % iStep3 == 0 ? 0 : 1);
|
else if (extraParam0 == 1) //Include padding
|
||||||
|
|
||||||
sum /= _a * _b; //Accounts for dilation
|
|
||||||
} else if (extraParam0 == 1) //Include padding
|
|
||||||
sum /= kProd;
|
sum /= kProd;
|
||||||
|
|
||||||
out[oi] = sum;
|
out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,76 +1,92 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyrigkht (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* Tkhis program and tkhe accompanying materials are made available under tkhe
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of tkhe Apackhe License, Version 2.0 wkhickh is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* khttps://www.apackhe.org/licenses/LICENSE-2.0.
|
||||||
*
|
*
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
* distributed under tkhe License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
* WARRANTIES OR CONDITIONS OF ANY KIND, eitkher express or implied. See tkhe
|
||||||
* License for the specific language governing permissions and limitations
|
* License for tkhe specific language governing permissions and limitations
|
||||||
* under the License.
|
* under tkhe License.
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apackhe-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @autkhor raver119@gmail.com
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/dilation2d.h>
|
#include <ops/declarable/helpers/dilation2d.h>
|
||||||
#include <array/DataTypeUtils.h>
|
#include <array/DataTypeUtils.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
template <typename X, typename Y>
|
|
||||||
static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left) {
|
|
||||||
const int batch = input->sizeAt(0);
|
|
||||||
const int input_rows = input->sizeAt(1);
|
|
||||||
const int input_cols = input->sizeAt(2);
|
|
||||||
const int depth = input->sizeAt(3);
|
|
||||||
|
|
||||||
const int filter_rows = weights->sizeAt(0);
|
//////////////////////////////////////////////////////////////////////
|
||||||
const int filter_cols = weights->sizeAt(1);
|
template <typename X, typename Z>
|
||||||
|
static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
|
||||||
|
|
||||||
const int output_rows = output->sizeAt(1);
|
// input [bS, iH, iW, iC]
|
||||||
const int output_cols = output->sizeAt(2);
|
// weights [kH, kW, iC]
|
||||||
|
// output [bS, oH, oW, iC]
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
const X* x = input->bufferAsT<X>();
|
||||||
for (int b = 0; b < batch; ++b) {
|
const X* y = weights->bufferAsT<X>();
|
||||||
for (int h_out = 0; h_out < output_rows; ++h_out) {
|
Z* z = output->bufferAsT<Z>();
|
||||||
int h_beg = h_out * stride_rows - pad_top;
|
|
||||||
for (int w_out = 0; w_out < output_cols; ++w_out) {
|
const Nd4jLong* xShapeInfo = input->getShapeInfo();
|
||||||
int w_beg = w_out * stride_cols - pad_left;
|
const Nd4jLong* yShapeInfo = weights->getShapeInfo();
|
||||||
for (int d = 0; d < depth; ++d) {
|
const Nd4jLong* zShapeInfo = output->getShapeInfo();
|
||||||
Y cur_val = -DataTypeUtils::max<Y>();
|
|
||||||
for (int h = 0; h < filter_rows; ++h) {
|
const uint bS = input->sizeAt(0);
|
||||||
const int h_in = h_beg + h * rate_rows;
|
const uint iH = input->sizeAt(1);
|
||||||
if (h_in >= 0 && h_in < input_rows) {
|
const uint iW = input->sizeAt(2);
|
||||||
for (int w = 0; w < filter_cols; ++w) {
|
const uint iC = input->sizeAt(3);
|
||||||
const int w_in = w_beg + w * rate_cols;
|
|
||||||
if (w_in >= 0 && w_in < input_cols) {
|
const uint kH = weights->sizeAt(0);
|
||||||
const Y val = input->e<Y>(b, h_in, w_in, d) + weights->e<Y>(h, w, d);
|
const uint kW = weights->sizeAt(1);
|
||||||
if (val > cur_val) {
|
|
||||||
cur_val = val;
|
const uint oH = output->sizeAt(1);
|
||||||
}
|
const uint oW = output->sizeAt(2);
|
||||||
}
|
|
||||||
}
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(4))
|
||||||
}
|
for (uint b = 0; b < bS; ++b) {
|
||||||
|
for (uint oh = 0; oh < oH; ++oh) {
|
||||||
|
for (uint ow = 0; ow < oW; ++ow) {
|
||||||
|
for (uint c = 0; c < iC; ++c) {
|
||||||
|
|
||||||
|
X max = -DataTypeUtils::max<X>();
|
||||||
|
|
||||||
|
for (uint kh = 0; kh < kH; ++kh) {
|
||||||
|
const int ih = oh * sH - pH + kh * dH;
|
||||||
|
if (ih < 0 || ih >= iH) continue;
|
||||||
|
|
||||||
|
for (uint kw = 0; kw < kW; ++kw) {
|
||||||
|
const int iw = ow * sW - pW + kw * dW;
|
||||||
|
if(iw < 0 || iw >= iW) continue;
|
||||||
|
|
||||||
|
const X val = x[shape::getOffset(xShapeInfo, {b,(uint)ih,(uint)iw,c})] + y[shape::getOffset(yShapeInfo, {kh,kw,c})];
|
||||||
|
if (val > max)
|
||||||
|
max = val;
|
||||||
}
|
}
|
||||||
(*output).p<Y>(b, h_out, w_out, d, cur_val);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
z[shape::getOffset(zShapeInfo, {b,oh,ow,c})] = static_cast<Z>(max);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
void dilation2d(nd4j::LaunchContext * context, NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left) {
|
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, stride_rows, stride_cols, rate_rows, rate_cols, pad_top, pad_left), LIBND4J_TYPES, FLOAT_TYPES);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left), LIBND4J_TYPES, FLOAT_TYPES);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,9 +44,9 @@ namespace helpers {
|
||||||
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
int dropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||||
//NativeOps native;
|
//NativeOps native;
|
||||||
//nd4j::graph::RandomGenerator nodeRng(seed); //static int _dropOutFunctor(nd4j::random::RandomBuffer* rng, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
//nd4j::graph::RandomGenerator nodeRng(seed); //static int dropOutFunctor_(nd4j::random::RandomBuffer* rng, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||||
//NativeOps native;
|
//NativeOps native;
|
||||||
//native.reSeedBuffer(nullptr, (long)seed, rng);
|
//native.reSeedBuffer(nullptr, (long)seed, rng);
|
||||||
//if (newRng )
|
//if (newRng )
|
||||||
|
@ -78,9 +78,9 @@ namespace helpers {
|
||||||
// broadcast chunk to full matrix
|
// broadcast chunk to full matrix
|
||||||
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
|
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
|
||||||
dropOutMultiplier->assign(1.f);
|
dropOutMultiplier->assign(1.f);
|
||||||
|
|
||||||
*dropOutMultiplier += *chunk;
|
*dropOutMultiplier += *chunk;
|
||||||
|
|
||||||
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
|
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -90,10 +90,10 @@ namespace helpers {
|
||||||
int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||||
auto xType = input->dataType();
|
auto xType = input->dataType();
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, return dropOutFunctor_, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template int _dropOutFunctor, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int dropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES);
|
||||||
|
|
||||||
/////////////////////////////////// backrpopagations ///////////////////////////////////////////////
|
/////////////////////////////////// backrpopagations ///////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -40,51 +40,75 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
|
||||||
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
|
NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
|
||||||
|
|
||||||
//Inputs:
|
//Inputs:
|
||||||
// x input [bS x inSize]
|
// x input [bS, nIn], nIn - input size
|
||||||
// hLast previous cell output [bS x numUnits], that is at previous time step t-1
|
// hLast previous cell output [bS, nUn], that is at previous time step t-1, nUn - number of units
|
||||||
// Wru RU weights - [bS, 2*numUnits] - reset and update gates
|
// Wru RU weights - [nIn+nUn, 2*nUn] - reset and update gates
|
||||||
// Wc C weights - [bS, numUnits] - cell gate
|
// Wc C weights - [nIn+nUn, nUn] - cell gate
|
||||||
// bru r and u biases, [2*numUnits] - reset and update gates
|
// bru r and u biases, [2*nUn] - reset and update gates
|
||||||
// bc c biases, [numUnits] - cell gate
|
// bc c biases, [nUn] - cell gate
|
||||||
|
|
||||||
//Outputs:
|
//Outputs:
|
||||||
// r Reset gate output [bS, numUnits]
|
// r Reset gate output [bS, nUn]
|
||||||
// u Update gate output [bS, numUnits]
|
// u Update gate output [bS, nUn]
|
||||||
// c Cell gate output [bS, numUnits]
|
// c Cell gate output [bS, nUn]
|
||||||
// h current cell output [bS, numUnits]
|
// h current cell output [bS, nUn]
|
||||||
|
|
||||||
|
/***************************************************************************************/
|
||||||
|
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
|
||||||
|
/** however it is more math-friendly and convenient for backprop formulas derivation) **/
|
||||||
|
|
||||||
|
const int bS = x->sizeAt(0);
|
||||||
const int nIn = x->sizeAt(1);
|
const int nIn = x->sizeAt(1);
|
||||||
const int nU = hLast->sizeAt(1); // number of units
|
const int nUn = hLast->sizeAt(1);
|
||||||
|
|
||||||
//Concat inputs: [x, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)]
|
NDArray Wr = (*Wru)({0,nIn, 0,0}); // reset gates weights [nIn, 2*nUn]
|
||||||
nd4j::ops::concat concatOp;
|
NDArray Wu = (*Wru)({nIn,nIn+nUn, 0,0}); // updates gates weights [nUn, 2*nUn]
|
||||||
std::vector<NDArray*> inputs;
|
|
||||||
std::vector<double> targs;
|
|
||||||
std::vector<Nd4jLong> iargs({1}); //Axis = 1
|
|
||||||
std::vector<bool> bargs;
|
|
||||||
inputs.emplace_back(const_cast<NDArray*>(x));
|
|
||||||
inputs.emplace_back(const_cast<NDArray*>(hLast));
|
|
||||||
|
|
||||||
auto result = concatOp.execute(inputs, targs, iargs, bargs);
|
NDArray Wcr = (*Wc)({0,nIn, 0,0}); // reset cell weights [nIn, nUn]
|
||||||
auto concatOut = result->at(0);
|
NDArray Wcu = (*Wc)({nIn,nIn+nUn, 0,0}); // updates cell weights [nUn, nUn]
|
||||||
|
|
||||||
//mmul/z for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u)
|
// gates = sigmoid(x*Wr + hLast*Wu + br + bu)
|
||||||
auto m = mmul(*concatOut, *Wru); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 2*numUnits] = [bs, 4*numUnits]
|
NDArray gates = mmul(*x, Wr) + mmul(*hLast, Wu) + *bru; // [bS, nIn] * [nIn, 2*nUn] + [bS, nUn] * [nUn, 2*nUn] + [2*nUn] = [bS, 2*nUn]
|
||||||
m += (*bru);
|
gates.applyTransform(transform::Sigmoid);
|
||||||
|
|
||||||
|
// reset gate
|
||||||
|
r->assign(gates({0,0, 0,nUn})); // [bS, nUn]
|
||||||
|
|
||||||
|
// update gate
|
||||||
|
u->assign(gates({0,0, nUn,2*nUn})); // [bS, nUn]
|
||||||
|
|
||||||
|
// cell gate c = activation(x*Wcr + (r◦hlast)*Wcu + bc)
|
||||||
|
c->assign(mmul(*x, Wcr) + mmul(*r * *hLast, Wcu) + *bc); // [bS, nIn] * [nIn, nUn] + [bS, nUn] * [nUn, nUn] + [nUn] = [bS, nUn]
|
||||||
|
c->applyTransform(transform::Tanh);
|
||||||
|
|
||||||
|
// cell output
|
||||||
|
h->assign(*u * *hLast + (1.f - *u) * *c);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/***************************************************************************************/
|
||||||
|
/********************** THIS MORE OPTIMAZED CODE (except concat ) **********************/
|
||||||
|
/***************************************************************************************/
|
||||||
|
/*
|
||||||
|
//Concat inputs: x + hLast : [bs, nIn + nUn]
|
||||||
|
NDArray xhConcat(x->ordering(), {bS, nIn + nUn}, x->dataType(), context); // concat([bs, nIn], [bs, nUn]) -> [bs, nIn + nUn]
|
||||||
|
helpers::concat(context, {const_cast<NDArray*>(x), const_cast<NDArray*>(hLast)}, xhConcat, {1});
|
||||||
|
|
||||||
|
//mmul for reset and update gates: (x * weight_ux + hLast * weight_xr + b_u)
|
||||||
|
auto m = mmul(xhConcat, *Wru) + *bru ; // [bs, nIn+nUn] * [nIn+nUn, 2*nUn] = [bs, 2*nUn]
|
||||||
|
// m += *bru;
|
||||||
|
|
||||||
sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz)
|
sigmoidInplace(m); //sigmoid(rz) and sigmoid(uz)
|
||||||
auto mr = m({0,0, 0, nU});
|
|
||||||
auto mu = m({0,0, nU, 2*nU});
|
|
||||||
|
|
||||||
r->assign(&mr);
|
r->assign(m({0,0, 0, nUn}));
|
||||||
u->assign(&mu);
|
u->assign(m({0,0, nUn, 2*nUn}));
|
||||||
|
|
||||||
//Concatenated inputs: [x, yt-1 .* r]
|
// hLast = hLast * r
|
||||||
auto yr = (*concatOut)({0,0, nIn, nIn+nU});
|
xhConcat({0,0, nIn, nIn+nUn}) *= *r;
|
||||||
yr *= (*r);
|
|
||||||
|
|
||||||
//c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c)
|
//c = tanh(x * weight_cx + (hLast .* r) * weight_cr + b_c)
|
||||||
MmulHelper::mmul(concatOut, const_cast<NDArray*>(Wc), c, 1.0, 0.0); //c = 1.0 * concatOut * Wc + 0.0 * c
|
MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c
|
||||||
*c += *bc;
|
*c += *bc;
|
||||||
tanhInplace(*c);
|
tanhInplace(*c);
|
||||||
|
|
||||||
|
@ -94,135 +118,134 @@ void gruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* hLa
|
||||||
auto temp = (1.0f - *u);
|
auto temp = (1.0f - *u);
|
||||||
temp *= (*c);
|
temp *= (*c);
|
||||||
(*h) += temp;
|
(*h) += temp;
|
||||||
|
*/
|
||||||
delete result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
|
void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {
|
||||||
|
|
||||||
// x input [time, bS, iS]
|
// x input [time, bS, iS]
|
||||||
// h0 initial cell output (at time step = 0) [bS, nU]
|
// h0 initial cell output (at time step = 0) [bS, nUn]
|
||||||
// Wx input-to-hidden weights, [iS, 3*nU]
|
// Wx input-to-hidden weights, [iS, 3*nUn]
|
||||||
// Wh hidden-to-hidden weights, [nU, 3*nU]
|
// Wh hidden-to-hidden weights, [nUn, 3*nUn]
|
||||||
// b biases, [3*nU]
|
// b biases, [3*nUn]
|
||||||
|
|
||||||
// h is cell outputs at each time step [time, bS, nU]
|
// h is cell outputs at each time step [time, bS, nUn]
|
||||||
|
|
||||||
const int time = x->sizeAt(0);
|
const int time = x->sizeAt(0);
|
||||||
|
|
||||||
NDArray ht_1(*h0);
|
NDArray ht_1(*h0);
|
||||||
|
|
||||||
// loop through time steps
|
// loop through time steps
|
||||||
for (int t = 0; t < time; ++t) {
|
for (int t = 0; t < time; ++t) {
|
||||||
|
|
||||||
auto xt = (*x)({t,t+1, 0,0, 0,0});
|
auto xt = (*x)({t,t+1, 0,0, 0,0});
|
||||||
auto ht = (*h)({t,t+1, 0,0, 0,0});
|
auto ht = (*h)({t,t+1, 0,0, 0,0});
|
||||||
|
|
||||||
//helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
|
//helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht);
|
||||||
//ht_1.assign(ht);
|
//ht_1.assign(ht);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
|
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
|
||||||
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
|
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
|
||||||
|
|
||||||
// x input [bS, iS]
|
// x input [bS, iS]
|
||||||
// h0 previous cell output [bS, nU], that is at previous time step t-1
|
// h0 previous cell output [bS, nUn], that is at previous time step t-1
|
||||||
// Wx input-to-hidden weights, [iS, 3*nU]
|
// Wx input-to-hidden weights, [iS, 3*nUn]
|
||||||
// Wh hidden-to-hidden weights, [nU, 3*nU]
|
// Wh hidden-to-hidden weights, [nUn, 3*nUn]
|
||||||
// b biases, [3*nU]
|
// b biases, [3*nUn]
|
||||||
// dLdh gradient wrt output, [bS,nU], that is epsilon_next
|
// dLdh gradient wrt output, [bS,nUn], that is epsilon_next
|
||||||
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nU]
|
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nUn]
|
||||||
// dLdWh0 gradient wrt Wh at previous time step, [nU, 3*nU]
|
// dLdWh0 gradient wrt Wh at previous time step, [nUn, 3*nUn]
|
||||||
// dLdb0 gradient wrt b at previous time step, [3*nU]
|
// dLdb0 gradient wrt b at previous time step, [3*nUn]
|
||||||
|
|
||||||
// dLdx gradient wrt x, [bS, iS], that is epsilon
|
// dLdx gradient wrt x, [bS, iS], that is epsilon
|
||||||
// dLdh0 gradient wrt h0, [bS, nU]
|
// dLdh0 gradient wrt h0, [bS, nUn]
|
||||||
// dLdWx gradient wrt Wx, [iS, 3*nU]
|
// dLdWx gradient wrt Wx, [iS, 3*nUn]
|
||||||
// dLdWh gradient wrt Wh, [nU, 3*nU]
|
// dLdWh gradient wrt Wh, [nUn, 3*nUn]
|
||||||
// dLdb gradient wrt b at previous time step, [3*nU]
|
// dLdb gradient wrt b at previous time step, [3*nUn]
|
||||||
|
|
||||||
// h is current cell output [bS, nU], that is at current time step t
|
// h is current cell output [bS, nUn], that is at current time step t
|
||||||
|
|
||||||
const int nU = h0->sizeAt(1);
|
const int nUn = h0->sizeAt(1);
|
||||||
|
|
||||||
// ***** feed forward step ***** //
|
// ***** feed forward step ***** //
|
||||||
// gates = sigmoid(x*Wx + h0*Wh + b)
|
// gates = sigmoid(x*Wx + h0*Wh + b)
|
||||||
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nU})) + mmul(*h0, (*Wh)({0,0, 0,2*nU})) + (*b)({0,2*nU})); // [bS, 2*nU] + [bS, 2*nU] + [1, 2*nU] = [bS, 2*nU]
|
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nUn})) + mmul(*h0, (*Wh)({0,0, 0,2*nUn})) + (*b)({0,2*nUn})); // [bS, 2*nUn] + [bS, 2*nUn] + [1, 2*nUn] = [bS, 2*nUn]
|
||||||
// reset gate
|
// reset gate
|
||||||
auto r = gates({0,0, 0, nU}); // [bS, nU]
|
auto r = gates({0,0, 0, nUn}); // [bS, nUn]
|
||||||
// update gate
|
// update gate
|
||||||
auto u = gates({0,0, nU, 2*nU}); // [bS, nU]
|
auto u = gates({0,0, nUn, 2*nUn}); // [bS, nUn]
|
||||||
// ◦ means element-wise product or so called Hadamard product
|
// ◦ means element-wise product or so called Hadamard product
|
||||||
// n = tanh(x*Wx + (r◦h0)*Wh + b)
|
// n = tanh(x*Wx + (r◦h0)*Wh + b)
|
||||||
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nU,3*nU})) + mmul((*h0)*r, (*Wh)({0,0, 2*nU,3*nU})) + (*b)({2*nU,3*nU})); // [bS, nU]
|
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nUn,3*nUn})) + mmul((*h0)*r, (*Wh)({0,0, 2*nUn,3*nUn})) + (*b)({2*nUn,3*nUn})); // [bS, nUn]
|
||||||
|
|
||||||
// ***** back prop step ***** //
|
// ***** back prop step ***** //
|
||||||
auto Wxr = (*Wx)({0,0, 0, nU});
|
auto Wxr = (*Wx)({0,0, 0, nUn});
|
||||||
auto Wxu = (*Wx)({0,0, nU, 2*nU});
|
auto Wxu = (*Wx)({0,0, nUn, 2*nUn});
|
||||||
auto Wxn = (*Wx)({0,0, 2*nU,3*nU});
|
auto Wxn = (*Wx)({0,0, 2*nUn,3*nUn});
|
||||||
auto Whr = (*Wh)({0,0, 0, nU});
|
auto Whr = (*Wh)({0,0, 0, nUn});
|
||||||
auto Whu = (*Wh)({0,0, nU, 2*nU});
|
auto Whu = (*Wh)({0,0, nUn, 2*nUn});
|
||||||
auto Whn = (*Wh)({0,0, 2*nU,3*nU});
|
auto Whn = (*Wh)({0,0, 2*nUn,3*nUn});
|
||||||
auto WxrT = Wxr.transpose();
|
auto WxrT = Wxr.transpose();
|
||||||
auto WxuT = Wxu.transpose();
|
auto WxuT = Wxu.transpose();
|
||||||
auto WxnT = Wxn.transpose();
|
auto WxnT = Wxn.transpose();
|
||||||
auto WhrT = Whr.transpose();
|
auto WhrT = Whr.transpose();
|
||||||
auto WhuT = Whu.transpose();
|
auto WhuT = Whu.transpose();
|
||||||
auto WhnT = Whn.transpose();
|
auto WhnT = Whn.transpose();
|
||||||
auto xT = x->transpose();
|
auto xT = x->transpose();
|
||||||
auto h0T = h0->transpose();
|
auto h0T = h0->transpose();
|
||||||
|
|
||||||
auto dLdWxr = (*dLdWx)({0,0, 0, nU});
|
auto dLdWxr = (*dLdWx)({0,0, 0, nUn});
|
||||||
auto dLdWxu = (*dLdWx)({0,0, nU, 2*nU});
|
auto dLdWxu = (*dLdWx)({0,0, nUn, 2*nUn});
|
||||||
auto dLdWxn = (*dLdWx)({0,0, 2*nU,3*nU});
|
auto dLdWxn = (*dLdWx)({0,0, 2*nUn,3*nUn});
|
||||||
|
|
||||||
auto dLdWhr = (*dLdWh)({0,0, 0, nU});
|
auto dLdWhr = (*dLdWh)({0,0, 0, nUn});
|
||||||
auto dLdWhu = (*dLdWh)({0,0, nU, 2*nU});
|
auto dLdWhu = (*dLdWh)({0,0, nUn, 2*nUn});
|
||||||
auto dLdWhn = (*dLdWh)({0,0, 2*nU,3*nU});
|
auto dLdWhn = (*dLdWh)({0,0, 2*nUn,3*nUn});
|
||||||
|
|
||||||
auto dLdbr = (*dLdb)({0, nU});
|
auto dLdbr = (*dLdb)({0, nUn});
|
||||||
auto dLdbu = (*dLdb)({nU, 2*nU});
|
auto dLdbu = (*dLdb)({nUn, 2*nUn});
|
||||||
auto dLdbn = (*dLdb)({2*nU,3*nU});
|
auto dLdbn = (*dLdb)({2*nUn,3*nUn});
|
||||||
|
|
||||||
auto dhdu = *h0 - n; // [bS, nU]
|
auto dhdu = *h0 - n; // [bS, nUn]
|
||||||
auto dhdn = 1.f - u; // [bS, nU]
|
auto dhdn = 1.f - u; // [bS, nUn]
|
||||||
auto dSigdu = u * (1.f - u); // [bS, nU]
|
auto dSigdu = u * (1.f - u); // [bS, nUn]
|
||||||
auto dSigdr = r * (1.f - r); // [bS, nU]
|
auto dSigdr = r * (1.f - r); // [bS, nUn]
|
||||||
auto dActdn = 1.f - n * n; // [bS, nU]
|
auto dActdn = 1.f - n * n; // [bS, nUn]
|
||||||
auto dndr = mmul(dActdn * (*h0), WhnT);
|
auto dndr = mmul(dActdn * (*h0), WhnT);
|
||||||
auto drdh0 = mmul(dSigdr, WhrT);
|
auto drdh0 = mmul(dSigdr, WhrT);
|
||||||
|
|
||||||
auto dLdn = (*dLdh) * dhdn;
|
auto dLdn = (*dLdh) * dhdn;
|
||||||
auto dLdu = (*dLdh) * dhdu;
|
auto dLdu = (*dLdh) * dhdu;
|
||||||
auto dLdr = dLdn * dndr;
|
auto dLdr = dLdn * dndr;
|
||||||
|
|
||||||
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
|
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
|
||||||
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nU]
|
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nUn]
|
||||||
|
|
||||||
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nU]
|
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nUn]
|
||||||
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nU,nU]
|
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nUn,nUn]
|
||||||
|
|
||||||
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nU]
|
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nUn]
|
||||||
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nU,nU]
|
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nUn,nUn]
|
||||||
|
|
||||||
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nU]
|
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nUn]
|
||||||
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nU,nU]
|
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nUn,nUn]
|
||||||
|
|
||||||
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nU]
|
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nUn]
|
||||||
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nU]
|
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nUn]
|
||||||
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nU]
|
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nUn]
|
||||||
|
|
||||||
if(dLdWx0 != nullptr)
|
if(dLdWx0 != nullptr)
|
||||||
*dLdWx += *dLdWx0;
|
*dLdWx += *dLdWx0;
|
||||||
|
|
||||||
if(dLdWh0 != nullptr)
|
if(dLdWh0 != nullptr)
|
||||||
*dLdWh += *dLdWh0;
|
*dLdWh += *dLdWh0;
|
||||||
|
|
||||||
if(dLdb0 != nullptr)
|
if(dLdb0 != nullptr)
|
||||||
*dLdb += *dLdb0;
|
*dLdb += *dLdb0;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -232,24 +255,24 @@ if(dLdb0 != nullptr)
|
||||||
// void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
|
// void gruTimeLoopBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
|
||||||
|
|
||||||
// NDArray<T>* x = inArrs[0]; // input [time, bS, iS]
|
// NDArray<T>* x = inArrs[0]; // input [time, bS, iS]
|
||||||
// NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nU], that is at previous time step t-1
|
// NDArray<T>* hi = inArrs[1]; // previous/initial cell output [bS, nUn], that is at previous time step t-1
|
||||||
// NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nU]
|
// NDArray<T>* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nUn]
|
||||||
// NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nU, 3*nU]
|
// NDArray<T>* Wh = inArrs[3]; // hidden-to-hidden weights, [nUn, 3*nUn]
|
||||||
// NDArray<T>* b = inArrs[4]; // biases, [3*nU]
|
// NDArray<T>* b = inArrs[4]; // biases, [3*nUn]
|
||||||
// NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nU], that is epsilon_next
|
// NDArray<T>* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nUn], that is epsilon_next
|
||||||
|
|
||||||
// NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon
|
// NDArray<T>* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon
|
||||||
// NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nU]
|
// NDArray<T>* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nUn]
|
||||||
// NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nU]
|
// NDArray<T>* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nUn]
|
||||||
// NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nU, 3*nU]
|
// NDArray<T>* dLdWh = outArrs[3]; // gradient wrt Wh, [nUn, 3*nUn]
|
||||||
// NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nU]
|
// NDArray<T>* dLdb = outArrs[4]; // gradient wrt b, [3*nUn]
|
||||||
|
|
||||||
// const Nd4jLong time = x->sizeAt(0);
|
// const Nd4jLong time = x->sizeAt(0);
|
||||||
// const Nd4jLong bS = x->sizeAt(1);
|
// const Nd4jLong bS = x->sizeAt(1);
|
||||||
// const Nd4jLong iS = x->sizeAt(2);
|
// const Nd4jLong iS = x->sizeAt(2);
|
||||||
// const Nd4jLong nU = hi->sizeAt(1);
|
// const Nd4jLong nUn = hi->sizeAt(1);
|
||||||
|
|
||||||
// NDArray<T> h(hi->ordering(), {time, bS, nU}); // feed forward output
|
// NDArray<T> h(hi->ordering(), {time, bS, nUn}); // feed forward output
|
||||||
|
|
||||||
// // first step, time = 0, feed forward
|
// // first step, time = 0, feed forward
|
||||||
// NDArray<T> x0 = (*x)({{0,1}, {}, {}});
|
// NDArray<T> x0 = (*x)({{0,1}, {}, {}});
|
||||||
|
|
|
@ -85,11 +85,11 @@ namespace helpers {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
|
PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
|
||||||
for (int i = 0; i < n; i++)
|
for (int i = 0; i < n; i++)
|
||||||
invertedMatrix->p(i, i, invertedMatrix->e<T>(i, i) / inputMatrix->e<T>(i, i));
|
invertedMatrix->t<T>(i, i) /= inputMatrix->t<T>(i, i);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
|
PRAGMA_OMP_PARALLEL_FOR_IF(n > Environment::getInstance()->elementwiseThreshold())
|
||||||
for (int i = 0; i < n - 1; i++)
|
for (int i = 0; i < n - 1; i++)
|
||||||
invertedMatrix->t<T>(i, i + 1) = invertedMatrix->t<T>(i, i+1) - (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i));
|
invertedMatrix->t<T>(i, i + 1) -= (inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i));
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_SIMD
|
// PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
for (int i = n - 2; i > - 1; i--) {
|
for (int i = n - 2; i > - 1; i--) {
|
||||||
|
@ -124,25 +124,25 @@ namespace helpers {
|
||||||
for(int i = 0; i < rowNum; i++ ) {
|
for(int i = 0; i < rowNum; i++ ) {
|
||||||
pivotValue = T(0.0);
|
pivotValue = T(0.0);
|
||||||
pivot = -1;
|
pivot = -1;
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR //_ARGS(firstprivate(pivot,pivotValue))
|
||||||
for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) {
|
for(int rowCounter = i; rowCounter < rowNum; rowCounter++ ) {
|
||||||
if(nd4j::math::nd4j_abs(compoundMatrix.e<T>(rowCounter, i)) > pivotValue ) {
|
if (nd4j::math::nd4j_abs(compoundMatrix.t<T>(rowCounter, i)) > pivotValue) {
|
||||||
pivotValue = nd4j::math::nd4j_abs(compoundMatrix.e<T>(rowCounter, i));
|
pivotValue = nd4j::math::nd4j_abs(compoundMatrix.t<T>(rowCounter, i));
|
||||||
pivot = rowCounter;
|
pivot = rowCounter;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if( pivotValue != T(0.0) ) {
|
if( pivotValue > T(0.00001)) {
|
||||||
swapRows(&compoundMatrix, pivot, i);
|
swapRows(&compoundMatrix, pivot, i);
|
||||||
swapRows(&permutationMatrix, pivot, i);
|
swapRows(&permutationMatrix, pivot, i);
|
||||||
if (pivot != i)
|
if (pivot != i)
|
||||||
swapCount++;
|
swapCount++;
|
||||||
|
|
||||||
for( int j = i + 1; j < rowNum; j++ ) {
|
for( int j = i + 1; j < rowNum; j++ ) {
|
||||||
compoundMatrix.p(j, i, compoundMatrix.e<T>(j, i) / compoundMatrix.e<T>(i, i));
|
compoundMatrix.t<T>(j, i) /= compoundMatrix.t<T>(i, i);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
for( int k = i + 1; k < rowNum; k++ ) {
|
for( int k = i + 1; k < rowNum; k++ ) {
|
||||||
T arg = compoundMatrix.e<T>(j, i) * compoundMatrix.e<T>(i, k);
|
compoundMatrix.t<T>(j, k) -= compoundMatrix.t<T>(j, i) * compoundMatrix.t<T>(i, k);
|
||||||
compoundMatrix.p(j, k, compoundMatrix.e<T>(j, k) - arg);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -188,7 +188,7 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int log_abs_determinant_(NDArray* input, NDArray* output) {
|
int logAbsDeterminant_(NDArray* input, NDArray* output) {
|
||||||
|
|
||||||
Nd4jLong n = input->sizeAt(-1);
|
Nd4jLong n = input->sizeAt(-1);
|
||||||
Nd4jLong n2 = n * n;
|
Nd4jLong n2 = n * n;
|
||||||
|
@ -206,14 +206,14 @@ template <typename T>
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template int log_abs_determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||||
|
|
||||||
int log_abs_determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return log_abs_determinant_, (input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (input, output), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int _inverse(NDArray* input, NDArray* output) {
|
static int inverse_(NDArray* input, NDArray* output) {
|
||||||
|
|
||||||
auto n = input->sizeAt(-1);
|
auto n = input->sizeAt(-1);
|
||||||
auto n2 = n * n;
|
auto n2 = n * n;
|
||||||
|
@ -236,7 +236,7 @@ template <typename T>
|
||||||
T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
|
T det = lup_<T>(&matrix, &compound, &permutation).template e<T>(0);
|
||||||
|
|
||||||
// FIXME: and how this is going to work on float16?
|
// FIXME: and how this is going to work on float16?
|
||||||
if (nd4j::math::nd4j_abs<T>(det) < T(0.0000001)) {
|
if (nd4j::math::nd4j_abs<T>(det) < T(0.000001)) {
|
||||||
nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det);
|
nd4j_printf("matrix_inverse: The matrix %i has no inverse due determinant is %lf. Quiting...\n", e, det);
|
||||||
matrix.printIndexedBuffer("Wrong matrix");
|
matrix.printIndexedBuffer("Wrong matrix");
|
||||||
return ND4J_STATUS_VALIDATION;
|
return ND4J_STATUS_VALIDATION;
|
||||||
|
@ -244,12 +244,12 @@ template <typename T>
|
||||||
lowerMatrix.setIdentity(); // set up U to identity matrix
|
lowerMatrix.setIdentity(); // set up U to identity matrix
|
||||||
for (int k = 1; k < n; k++) { // and then put all values under main diagonal on to it
|
for (int k = 1; k < n; k++) { // and then put all values under main diagonal on to it
|
||||||
for (int j = 0; j < k; j++)
|
for (int j = 0; j < k; j++)
|
||||||
lowerMatrix.p(k, j, compound.template e<T>(k, j));
|
lowerMatrix.template t<T>(k, j) = compound.template t<T>(k, j);
|
||||||
}
|
}
|
||||||
upperMatrix.setIdentity(); // set up U to identity matrix
|
upperMatrix.setIdentity(); // set up U to identity matrix
|
||||||
for (int k = 0; k < n; k++) { // and then put all values under main diagonal on to it
|
for (int k = 0; k < n; k++) { // and then put all values under main diagonal on to it
|
||||||
for (int j = k; j < n; j++)
|
for (int j = k; j < n; j++)
|
||||||
upperMatrix.p(k, j, compound.template e<T>(k, j));
|
upperMatrix.template t<T>(k, j) = compound.template e<T>(k, j);
|
||||||
}
|
}
|
||||||
invertUpperMatrix(&upperMatrix, &matrix);
|
invertUpperMatrix(&upperMatrix, &matrix);
|
||||||
|
|
||||||
|
@ -258,7 +258,7 @@ template <typename T>
|
||||||
nd4j::MmulHelper::mmul(&matrix, &upperMatrix, &compound, 1.0, 0.0);
|
nd4j::MmulHelper::mmul(&matrix, &upperMatrix, &compound, 1.0, 0.0);
|
||||||
nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0);
|
nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0);
|
||||||
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||||
output->p(k, matrix.template e<T>(row++));
|
output->t<T>(k) = matrix.template t<T>(row++);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -266,7 +266,7 @@ template <typename T>
|
||||||
}
|
}
|
||||||
|
|
||||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return _inverse, (input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (input, output), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -346,7 +346,7 @@ template <typename T>
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template int cholesky_, (NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int cholesky_, (NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template int _inverse, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int inverse_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int logdetFunctor_(NDArray* input, NDArray* output) {
|
int logdetFunctor_(NDArray* input, NDArray* output) {
|
||||||
|
|
|
@ -22,125 +22,122 @@
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <helpers/ShapeUtils.h>
|
#include <helpers/ShapeUtils.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
|
|
||||||
|
|
||||||
const int outRank = output.rankOf();
|
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
|
||||||
const int indRank = indices.rankOf();
|
|
||||||
const int updRank = updates.rankOf();
|
|
||||||
const Nd4jLong indLen = indices.lengthOf();
|
|
||||||
|
|
||||||
if(outRank == 1) {
|
const int outRank = output.rankOf();
|
||||||
|
const int indRank = indices.rankOf();
|
||||||
|
const int updRank = updates.rankOf();
|
||||||
|
const Nd4jLong indLen = indices.lengthOf();
|
||||||
|
|
||||||
|
if(outRank == 1) {
|
||||||
|
|
||||||
|
<<<<<<< HEAD
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
||||||
for(Nd4jLong i = 0; i < indLen; ++i) {
|
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||||
|
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||||
|
|
||||||
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
||||||
NDArray out = output({idx, idx+1});
|
NDArray out = output({idx, idx+1});
|
||||||
|
|
||||||
out.applyPairwiseTransform(op, updates.e(i), nullptr);
|
out.applyPairwiseTransform(op, updates.e(i), nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else { // outRank > 1
|
else { // outRank > 1
|
||||||
|
|
||||||
int sizeOfDims = indRank;
|
int sizeOfDims = indRank;
|
||||||
if(outRank == updRank && indices.isVector())
|
if(outRank == updRank && indices.isVector())
|
||||||
sizeOfDims = 1;
|
sizeOfDims = 1;
|
||||||
|
|
||||||
std::vector<int> dimsToExcludeUpd(sizeOfDims);
|
std::vector<int> dimsToExcludeUpd(sizeOfDims);
|
||||||
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug !
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
||||||
for(Nd4jLong i = 0; i < indLen; ++i) {
|
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||||
|
|
||||||
NDArray outSubArr = output(indices.e<Nd4jLong>(i), std::vector<int>({0}));
|
NDArray outSubArr = output(indices.e<Nd4jLong>(i), std::vector<int>({0}));
|
||||||
NDArray updSubArr = updates(i, dimsToExcludeUpd);
|
NDArray updSubArr = updates(i, dimsToExcludeUpd);
|
||||||
|
|
||||||
outSubArr.applyPairwiseTransform(op, updSubArr, nullptr);
|
outSubArr.applyPairwiseTransform(op, updSubArr, nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
|
void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock) {
|
||||||
|
|
||||||
const Nd4jLong indLen = indices.lengthOf();
|
const Nd4jLong indLen = indices.lengthOf();
|
||||||
const int outRank = output.rankOf();
|
const int outRank = output.rankOf();
|
||||||
const int indRank = indices.rankOf();
|
const int indRank = indices.rankOf();
|
||||||
const Nd4jLong indLastDim = indices.sizeAt(-1);
|
const Nd4jLong indLastDim = indices.sizeAt(-1);
|
||||||
|
|
||||||
if(outRank == 1) {
|
if(outRank == 1) {
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided))
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided))
|
||||||
for(Nd4jLong i = 0; i < indLen; ++i) {
|
for(Nd4jLong i = 0; i < indLen; ++i) {
|
||||||
|
|
||||||
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
Nd4jLong idx = indices.e<Nd4jLong>(i);
|
||||||
NDArray out = output({idx, idx+1});
|
NDArray out = output({idx, idx+1});
|
||||||
|
|
||||||
out.applyPairwiseTransform(op, updates.e(i), nullptr);
|
out.applyPairwiseTransform(op, updates.e(i), nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
||||||
std::vector<int> dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1});
|
std::vector<int> dimsToExcludeInd = ShapeUtils::evalDimsToExclude(indRank, {indRank-1});
|
||||||
std::vector<int> dimsToExcludeUpd(indRank - 1);
|
std::vector<int> dimsToExcludeUpd(indRank - 1);
|
||||||
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0);
|
||||||
std::vector<Nd4jLong> idxRangeOut(2*outRank, 0);
|
std::vector<Nd4jLong> idxRangeOut(2*outRank, 0);
|
||||||
|
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
|
// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut))
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut))
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut))
|
||||||
for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) {
|
for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) {
|
||||||
|
|
||||||
NDArray indSubArr = indices(i, dimsToExcludeInd);
|
NDArray indSubArr = indices(i, dimsToExcludeInd);
|
||||||
|
|
||||||
for(Nd4jLong j = 0; j < indLastDim; ++j) {
|
for(Nd4jLong j = 0; j < indLastDim; ++j) {
|
||||||
idxRangeOut[2*j] = indSubArr.e<Nd4jLong>(j);
|
idxRangeOut[2*j] = indSubArr.e<Nd4jLong>(j);
|
||||||
idxRangeOut[2*j + 1] = idxRangeOut[2*j] + 1;
|
idxRangeOut[2*j + 1] = idxRangeOut[2*j] + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
NDArray outSubArr = output(idxRangeOut);
|
NDArray outSubArr = output(idxRangeOut);
|
||||||
NDArray updSubArr = updates(i, dimsToExcludeUpd);
|
NDArray updSubArr = updates(i, dimsToExcludeUpd);
|
||||||
|
|
||||||
outSubArr.applyPairwiseTransform(op, updSubArr, nullptr);
|
outSubArr.applyPairwiseTransform(op, updSubArr, nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void scatterForLoss(nd4j::LaunchContext *context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) {
|
||||||
|
|
||||||
void scatterForLoss(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& updates, NDArray& output, const bool calcGrad) {
|
// shapes of indices and output must be the same
|
||||||
// requirements for arrays
|
// shape of indices should be the same as updates shape with last dimension excluded
|
||||||
// shapes of updates and output must be the same
|
// for example if updates is {a,b,c} then indices should be {a,b}
|
||||||
// shape of indices should be the same as updates shape with last dimension excluded
|
|
||||||
// for example if updates is {a,b,c} then indices should be {a,b}
|
const Nd4jLong indicesLen = indices.lengthOf();
|
||||||
|
|
||||||
const Nd4jLong indicesLen = indices.lengthOf();
|
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(updates.rankOf(), {-1});
|
||||||
|
|
||||||
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(updates.rankOf(), {-1});
|
if(!calcGrad) {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided))
|
||||||
if(!calcGrad) {
|
for(Nd4jLong i = 0; i < indicesLen; ++i) {
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided))
|
|
||||||
for(Nd4jLong i = 0; i < indicesLen; ++i) {
|
auto subArr = updates(i, dimsToExclude);
|
||||||
|
output.p(i, subArr.e(indices.e<Nd4jLong>(i)));
|
||||||
auto subArr = updates(i, dimsToExclude);
|
}
|
||||||
output.p(i, subArr.e(indices.e<Nd4jLong>(i)));
|
} else {
|
||||||
}
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided))
|
||||||
}
|
for(Nd4jLong i = 0; i < indicesLen; ++i) {
|
||||||
else {
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided))
|
auto subArr = updates(i, dimsToExclude);
|
||||||
for(Nd4jLong i = 0; i < indicesLen; ++i) {
|
auto ind = indices.e<Nd4jLong>(i);
|
||||||
|
subArr.p(ind, subArr.e(ind) - 1.);
|
||||||
auto subArr = updates(i, dimsToExclude);
|
|
||||||
auto ind = indices.e<Nd4jLong>(i);
|
|
||||||
subArr.p(ind, subArr.e(ind) - 1.);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
|
|
||||||
#include<ops/declarable/helpers/sru.h>
|
#include<ops/declarable/helpers/sru.h>
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
|
#include <MmulHelper.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -108,27 +109,27 @@ void sruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
|
static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
|
||||||
|
|
||||||
// x input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
|
// x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features
|
||||||
// w 2d tensor of weights [2*inSize x 6*inSize]
|
// w 2d tensor of weights [2*K x 6*K]
|
||||||
// b row of biases with twice length [1 x 4*inSize]
|
// b row of biases with twice length [4*K]
|
||||||
// c0 2d tensor of initial state [bS x 2*inSize] at time t=0
|
// c0 2d tensor of initial state [bS x 2*K] at time t=0
|
||||||
// mask optional, 2d tensor of dropout mask [bS x 2*inSize]
|
// mask optional, 2d tensor of dropout mask [bS x 2*K]
|
||||||
|
|
||||||
// ht [time x bS x 2*inSize]
|
// ht [time x bS x 2*K]
|
||||||
// ct [time x bS x 2*inSize]
|
// ct [time x bS x 2*K]
|
||||||
|
|
||||||
const Nd4jLong time = x->sizeAt(0); // time - number of time steps
|
const Nd4jLong time = x->sizeAt(0); // time - number of time steps
|
||||||
const Nd4jLong bS = x->sizeAt(1); // bS - batch size
|
const Nd4jLong bS = x->sizeAt(1); // bS - batch size
|
||||||
const Nd4jLong inSize = x->sizeAt(2) / 2; // inSize - number of features
|
const Nd4jLong K = x->sizeAt(2) / 2; // K - number of features
|
||||||
|
|
||||||
// x = x * mask
|
// x = x * mask
|
||||||
if(mask)
|
if(mask)
|
||||||
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
|
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
|
||||||
|
|
||||||
// U = x * w
|
// U = x * w
|
||||||
NDArray wi = mmul(*x, *w); // U [time x bS x 6*inSize]
|
NDArray wi = mmul(*x, *w); // U [time x bS x 6*K]
|
||||||
|
|
||||||
const Nd4jLong d2 = 2*inSize;
|
const Nd4jLong d2 = 2*K;
|
||||||
const Nd4jLong ncols = bS*d2;
|
const Nd4jLong ncols = bS*d2;
|
||||||
const Nd4jLong ncolsWi = 3*ncols;
|
const Nd4jLong ncolsWi = 3*ncols;
|
||||||
|
|
||||||
|
@ -140,42 +141,39 @@ static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray
|
||||||
T* pHt = ht->bufferAsT<T>();
|
T* pHt = ht->bufferAsT<T>();
|
||||||
T* pCt = ct->bufferAsT<T>();
|
T* pCt = ct->bufferAsT<T>();
|
||||||
|
|
||||||
Nd4jLong ncolsRev, ncolsWiRev; // for reverse direction
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
T maskVal, cur, bF, bR, ft, rt, val;
|
|
||||||
T *pIVal(nullptr), *pWiVal(nullptr), *pHtVal(nullptr), *pCtVal(nullptr);
|
|
||||||
bool flip = false;
|
|
||||||
|
|
||||||
for (Nd4jLong col = 0; col < ncols; ++col) {
|
for (Nd4jLong col = 0; col < ncols; ++col) {
|
||||||
|
|
||||||
const auto colNum = col % d2;
|
const auto colNum = col % d2;
|
||||||
flip = colNum >= inSize;
|
bool flip = colNum >= K;
|
||||||
maskVal = mask ? *(pMask + col) : T(1);
|
T maskVal = mask ? *(pMask + col) : T(1);
|
||||||
cur = *(pInit + col);
|
T cur = *(pInit + col);
|
||||||
bF = *(pBias + colNum);
|
T bF = *(pBias + colNum);
|
||||||
bR = *(pBias + colNum + d2);
|
T bR = *(pBias + colNum + d2);
|
||||||
pWiVal = pWi + 3*col;
|
T* pWiVal = pWi + 3*col;
|
||||||
pIVal = pI + col;
|
T* pIVal = pI + col;
|
||||||
pHtVal = pHt + col;
|
T* pHtVal = pHt + col;
|
||||||
pCtVal = pCt + col;
|
T* pCtVal = pCt + col;
|
||||||
|
|
||||||
if (flip) {
|
if (flip) {
|
||||||
pIVal += (time-1)*ncols;
|
const auto step = (time - 1) * ncols;
|
||||||
pWiVal += (time-1)*ncolsWi;
|
pIVal += step;
|
||||||
pHtVal += (time-1)*ncols;
|
pHtVal += step;
|
||||||
pCtVal += (time-1)*ncols;
|
pCtVal += step;
|
||||||
|
pWiVal += (time - 1) * ncolsWi;
|
||||||
}
|
}
|
||||||
|
|
||||||
ncolsRev = flip ? -ncols : ncols;
|
auto ncolsRev = flip ? -ncols : ncols;
|
||||||
ncolsWiRev = flip ? -ncolsWi : ncolsWi;
|
auto ncolsWiRev = flip ? -ncolsWi : ncolsWi;
|
||||||
|
|
||||||
for (Nd4jLong t = 0; t < time; ++t) {
|
for (Nd4jLong t = 0; t < time; ++t) {
|
||||||
// evaluate sigmoids
|
// evaluate sigmoids
|
||||||
ft = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(*(pWiVal + 1) + bF)));
|
T ft = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(pWiVal[1] + bF)));
|
||||||
rt = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(*(pWiVal + 2) + bR)));
|
T rt = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(pWiVal[2] + bR)));
|
||||||
|
|
||||||
cur = (cur - *pWiVal)*ft + *pWiVal;
|
cur = (cur - *pWiVal)*ft + *pWiVal;
|
||||||
*pCtVal = cur;
|
*pCtVal = cur;
|
||||||
val = nd4j::math::nd4j_tanh<T, T>(cur);
|
T val = nd4j::math::nd4j_tanh<T, T>(cur);
|
||||||
*pHtVal = (val*maskVal - *pIVal)*rt + *pIVal;
|
*pHtVal = (val*maskVal - *pIVal)*rt + *pIVal;
|
||||||
|
|
||||||
pIVal += ncolsRev;
|
pIVal += ncolsRev;
|
||||||
|
@ -191,34 +189,34 @@ template <typename T>
|
||||||
static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask,
|
static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask,
|
||||||
NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
|
NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
|
||||||
|
|
||||||
// x input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
|
// x input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features
|
||||||
// w 2d tensor of weights [2*inSize x 6*inSize]
|
// w 2d tensor of weights [2*K x 6*K]
|
||||||
// b row of biases with twice length [1 x 4*inSize]
|
// b row of biases with twice length 4*K]
|
||||||
// c0 2d tensor of initial state [bS x 2*inSize] at time t=0
|
// c0 2d tensor of initial state [bS x 2*K] at time t=0
|
||||||
// ct [time x bS x 2*inSize]
|
// ct [time x bS x 2*K]
|
||||||
// inGradC0 [bS x 2*inSize]
|
// inGradC0 [bS x 2*K]
|
||||||
// inGradHt [time x bS x 2*inSize]
|
// inGradHt [time x bS x 2*K]
|
||||||
// mask optional, 2d tensor of dropout mask [bS x 2*inSize]
|
// mask optional, 2d tensor of dropout mask [bS x 2*K]
|
||||||
|
|
||||||
// gradI [time x bS x 2*inSize]
|
// gradI [time x bS x 2*K]
|
||||||
// gradW [time x 2*inSize x 6*inSize]
|
// gradW [time x 2*K x 6*K]
|
||||||
// gradB [1 x 4*inSize]
|
// gradB [4*K]
|
||||||
// gradC0 [bS x 2*inSize]
|
// gradC0 [bS x 2*K]
|
||||||
|
|
||||||
const Nd4jLong time = x->sizeAt(0); // time - number of time steps
|
const Nd4jLong time = x->sizeAt(0); // time - number of time steps
|
||||||
const Nd4jLong bS = x->sizeAt(1);
|
const Nd4jLong bS = x->sizeAt(1);
|
||||||
const Nd4jLong inSize = x->sizeAt(2) / 2;
|
const Nd4jLong K = x->sizeAt(2) / 2;
|
||||||
|
|
||||||
// x = x * mask
|
// x = x * mask
|
||||||
if(mask)
|
if(mask)
|
||||||
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
|
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
|
||||||
|
|
||||||
// U = x * w
|
// U = x * w
|
||||||
NDArray wi = mmul(*x, *w); // [time x bS x 2*inSize] * [2*inSize x 6*inSize] = [time x bS x 6*inSize]
|
NDArray wi = mmul(*x, *w); // [time x bS x 2*K] * [2*K x 6*K] = [time x bS x 6*K]
|
||||||
NDArray gradBias(x->ordering(), {bS, 4*inSize}, x->dataType(), x->getContext());
|
NDArray gradBias(x->ordering(), {bS, 4*K}, x->dataType(), x->getContext());
|
||||||
NDArray gradWi (x->ordering(), {time, bS, 6*inSize}, x->dataType(), x->getContext());
|
NDArray gradWi (x->ordering(), {time, bS, 6*K}, x->dataType(), x->getContext());
|
||||||
|
|
||||||
const Nd4jLong d2 = 2*inSize;
|
const Nd4jLong d2 = 2*K;
|
||||||
const Nd4jLong ncols = bS*d2;
|
const Nd4jLong ncols = bS*d2;
|
||||||
const Nd4jLong ncolsWi = 3*ncols;
|
const Nd4jLong ncolsWi = 3*ncols;
|
||||||
T* pInput = x->bufferAsT<T>();
|
T* pInput = x->bufferAsT<T>();
|
||||||
|
@ -234,59 +232,61 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr
|
||||||
T* pGradBias = gradBias.bufferAsT<T>();
|
T* pGradBias = gradBias.bufferAsT<T>();
|
||||||
T* pGradInit = gradC0->bufferAsT<T>();
|
T* pGradInit = gradC0->bufferAsT<T>();
|
||||||
|
|
||||||
Nd4jLong ncolsRev, ncolsWiRev; // for reverse direction
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
T gbF, gbR, cur, maskVal, bF, bR, ft, rt, val, prevVal, gft, grt, gradSateVal;
|
|
||||||
bool flip = false;
|
|
||||||
T *pInputVal(nullptr), *pWiVal(nullptr), *pStateVal(nullptr), *pInGradHtVal(nullptr), *pGradWiVal(nullptr), *pGradInputVal(nullptr);
|
|
||||||
|
|
||||||
for (Nd4jLong col = 0; col < ncols; ++col) {
|
for (Nd4jLong col = 0; col < ncols; ++col) {
|
||||||
gbF = gbR = (T)0.;
|
T gbF = 0.f;
|
||||||
|
T gbR = 0.f;
|
||||||
const auto colNum = col % d2;
|
const auto colNum = col % d2;
|
||||||
flip = colNum >= inSize;
|
const bool flip = colNum >= K;
|
||||||
maskVal = mask ? *(pMask + col) : T(1.);
|
T maskVal = mask ? *(pMask + col) : T(1.);
|
||||||
cur = *(pInGradCt + col);
|
T cur = *(pInGradCt + col);
|
||||||
bF = *(pBias + colNum);
|
T bF = *(pBias + colNum);
|
||||||
bR = *(pBias + colNum + d2);
|
T bR = *(pBias + colNum + d2);
|
||||||
pWiVal = pWi + 3*col;
|
T* pWiVal = pWi + 3*col;
|
||||||
pInputVal = pInput + col;
|
T* pInputVal = pInput + col;
|
||||||
pStateVal = pState + col;
|
T* pStateVal = pState + col;
|
||||||
pInGradHtVal = pInGradHt + col;
|
T* pInGradHtVal = pInGradHt + col;
|
||||||
pGradWiVal = pGradWi + 3*col;
|
T* pGradWiVal = pGradWi + 3*col;
|
||||||
pGradInputVal = pGradInput + col;
|
T* pGradInputVal = pGradInput + col;
|
||||||
|
|
||||||
if (!flip) {
|
if (!flip) {
|
||||||
pInputVal += (time-1)*ncols;
|
const auto stepI = (time - 1) * ncols;
|
||||||
pWiVal += (time-1)*ncolsWi;
|
const auto stepW = (time - 1) * ncolsWi;
|
||||||
pStateVal += (time-1)*ncols;
|
pInputVal += stepI;
|
||||||
pInGradHtVal += (time-1)*ncols;
|
pStateVal += stepI;
|
||||||
pGradWiVal += (time-1)*ncolsWi;
|
pInGradHtVal += stepI;
|
||||||
pGradInputVal += (time-1)*ncols;
|
pGradInputVal += stepI;
|
||||||
|
pWiVal += stepW;
|
||||||
|
pGradWiVal += stepW;
|
||||||
}
|
}
|
||||||
ncolsRev = flip ? -ncols : ncols;
|
|
||||||
ncolsWiRev = flip ? -ncolsWi : ncolsWi;
|
Nd4jLong ncolsRev = flip ? -ncols : ncols;
|
||||||
|
Nd4jLong ncolsWiRev = flip ? -ncolsWi : ncolsWi;
|
||||||
|
|
||||||
for (Nd4jLong t = 0; t < time; ++t) {
|
for (Nd4jLong t = 0; t < time; ++t) {
|
||||||
// evaluate sigmoids
|
// evaluate sigmoids
|
||||||
ft = ((T)1.)/((T)1. + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 1) + bF)));
|
T ft = ((T)1.)/((T)1. + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 1) + bF)));
|
||||||
rt = ((T)1.)/((T)1. + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 2) + bR)));
|
T rt = ((T)1.)/((T)1. + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 2) + bR)));
|
||||||
|
|
||||||
val = nd4j::math::nd4j_tanh<T,T>(*pStateVal);
|
T val = nd4j::math::nd4j_tanh<T,T>(*pStateVal);
|
||||||
prevVal = (t < time-1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col));
|
T prevVal = (t < time-1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col));
|
||||||
// grad wrt input
|
// grad wrt input
|
||||||
*pGradInputVal = *pInGradHtVal - (*pInGradHtVal)*rt ;
|
*pGradInputVal = *pInGradHtVal - (*pInGradHtVal)*rt ;
|
||||||
// grad wrt rt, wiR and bR
|
// grad wrt rt, wiR and bR
|
||||||
grt = (*pInGradHtVal) * (val*maskVal - *pInputVal) * (rt - rt*rt);
|
T grt = (*pInGradHtVal) * (val*maskVal - *pInputVal) * (rt - rt*rt);
|
||||||
*(pGradWiVal + 2) = grt;
|
*(pGradWiVal + 2) = grt;
|
||||||
gbR += grt;
|
gbR += grt;
|
||||||
// grad wrt state
|
// grad wrt state
|
||||||
gradSateVal = (*pInGradHtVal) * maskVal * (rt - rt*val*val) + cur;
|
T gradSateVal = (*pInGradHtVal) * maskVal * (rt - rt*val*val) + cur;
|
||||||
// grad wrt wi0
|
// grad wrt wi0
|
||||||
*pGradWiVal = gradSateVal - gradSateVal*ft;
|
*pGradWiVal = gradSateVal - gradSateVal*ft;
|
||||||
// grad wrt ft, wi1, and bF
|
// grad wrt ft, wi1, and bF
|
||||||
gft = gradSateVal * (prevVal - *pWiVal) * (ft - ft*ft);
|
T gft = gradSateVal * (prevVal - *pWiVal) * (ft - ft*ft);
|
||||||
*(pGradWiVal + 1) = gft;
|
*(pGradWiVal + 1) = gft;
|
||||||
gbF += gft;
|
gbF += gft;
|
||||||
// grad wrt c_previous
|
// grad wrt c_previous
|
||||||
cur = gradSateVal * ft;
|
cur = gradSateVal * ft;
|
||||||
|
|
||||||
pInputVal -= ncolsRev;
|
pInputVal -= ncolsRev;
|
||||||
pWiVal -= ncolsWiRev;
|
pWiVal -= ncolsWiRev;
|
||||||
pStateVal -= ncolsRev;
|
pStateVal -= ncolsRev;
|
||||||
|
@ -300,11 +300,11 @@ static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArr
|
||||||
}
|
}
|
||||||
|
|
||||||
// gradB
|
// gradB
|
||||||
gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}, false, true); // [1 x 4*inSize]
|
gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}); // [4*K]
|
||||||
|
|
||||||
// gradW
|
// gradW
|
||||||
x->permutei({0, 2, 1}); // [time x bS x 2*inSize] -> [time x 2*inSize x bS]
|
x->permutei({0, 2, 1}); // [time x bS x 2*K] -> [time x 2*K x bS]
|
||||||
*gradW = mmul(*x, gradWi); // [time x 2*inSize x bS ] * [time x bS x 6*inSize] = [time x 2*inSize x 6*inSize]
|
MmulHelper::mmul(x, &gradWi, gradW, 1., 0.); // [time x 2*K x bS ] * [time x bS x 6*K] = [time x 2*K x 6*K]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -43,8 +43,8 @@ static void triuBP_(nd4j::LaunchContext * context, const NDArray& input, const N
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_IF(dLen > Environment::getInstance()->elementwiseThreshold())
|
PRAGMA_OMP_PARALLEL_FOR_IF(dLen > Environment::getInstance()->elementwiseThreshold())
|
||||||
for(int i = 0; i < dLen; ++i) {
|
for(int i = 0; i < dLen; ++i) {
|
||||||
if(dOdI.e<T>(i) != (T)0.f)
|
if(dOdI.t<T>(i) != static_cast<T>(0.f))
|
||||||
dOdI.p(i, T(1.f));
|
dOdI.t<T>(i) = static_cast<T>(1.f);
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: !!!
|
// FIXME: !!!
|
||||||
|
|
|
@ -32,30 +32,48 @@ namespace helpers {
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
__global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
__global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
const void *vy, const Nd4jLong *yShapeInfo,
|
const void *vy, const Nd4jLong *yShapeInfo,
|
||||||
void *vz) {
|
void *vz) {
|
||||||
|
|
||||||
const auto x = reinterpret_cast<const X*>(vx);
|
const auto x = reinterpret_cast<const X*>(vx);
|
||||||
const auto y = reinterpret_cast<const Y*>(vy);
|
const auto y = reinterpret_cast<const Y*>(vy);
|
||||||
auto z = reinterpret_cast<X*>(vz);
|
auto z = reinterpret_cast<X*>(vz);
|
||||||
|
|
||||||
__shared__ Nd4jLong len;
|
__shared__ Nd4jLong xzLen, totalThreads, *sharedMem;
|
||||||
|
__shared__ int xzRank, yRank;
|
||||||
|
|
||||||
if (threadIdx.x == 0)
|
if (threadIdx.x == 0) {
|
||||||
len = shape::length(xShapeInfo);
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
xzLen = shape::length(xShapeInfo);
|
||||||
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
xzRank = shape::rank(xShapeInfo);
|
||||||
|
yRank = shape::rank(yShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto totalThreads = gridDim.x * blockDim.x;
|
Nd4jLong* coords = sharedMem + threadIdx.x * xzRank;
|
||||||
|
|
||||||
for (int i = tid; i < len; i += totalThreads) {
|
for (int i = tid; i < xzLen; i += totalThreads) {
|
||||||
|
|
||||||
const auto xzOffset = shape::getIndexOffset(i, xShapeInfo, len);
|
shape::index2coords(xzRank, xShapeInfo + 1, i, xzLen, coords);
|
||||||
const auto xVal = x[xzOffset];
|
|
||||||
|
|
||||||
if(xVal < 0)
|
const auto xzOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xzRank + 1, coords, xzRank);
|
||||||
z[xzOffset] = xVal * y[shape::subArrayOffset(i, xShapeInfo, yShapeInfo)];
|
|
||||||
|
const auto xVal = x[xzOffset];
|
||||||
|
|
||||||
|
if(xVal < 0) {
|
||||||
|
|
||||||
|
for (uint j = 0; j < yRank; ++j)
|
||||||
|
if(yShapeInfo[j + 1] == 1)
|
||||||
|
coords[j + 1] = 0;
|
||||||
|
|
||||||
|
z[xzOffset] = xVal * y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + yRank + 1, coords + 1, yRank)];
|
||||||
|
}
|
||||||
else
|
else
|
||||||
z[xzOffset] = xVal;
|
z[xzOffset] = xVal;
|
||||||
}
|
}
|
||||||
|
@ -63,28 +81,121 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) {
|
linkage void preluCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz) {
|
||||||
|
|
||||||
preluCuda<X, Y><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz);
|
preluCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz);
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) {
|
void prelu(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, NDArray& output) {
|
||||||
if(!input.isActualOnDeviceSide()) input.syncToDevice();
|
|
||||||
if(!alpha.isActualOnDeviceSide()) alpha.syncToDevice();
|
PointersManager manager(context, "prelu");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||||
|
|
||||||
const auto xType = input.dataType();
|
const auto xType = input.dataType();
|
||||||
const auto yType = alpha.dataType();
|
const auto yType = alpha.dataType();
|
||||||
int threadsPerBlock = MAX_NUM_THREADS;
|
|
||||||
int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, yType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), output.getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES);
|
NDArray::prepareSpecialUse({&output}, {&input, &alpha});
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, yType, preluCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), output.getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input, &alpha});
|
||||||
|
|
||||||
input.tickReadHost();
|
manager.synchronize();
|
||||||
alpha.tickReadHost();
|
|
||||||
output.tickWriteDevice();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Y>
|
||||||
|
__global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeInfo,
|
||||||
|
const void *vAlpha, const Nd4jLong *alphaShapeInfo,
|
||||||
|
const void *vdLdO, const Nd4jLong *dLdOShapeInfo,
|
||||||
|
void *vdLdI, const Nd4jLong *dLdIShapeInfo,
|
||||||
|
void *vdLdA, const Nd4jLong *dLdAShapeInfo) {
|
||||||
|
|
||||||
|
const auto in = reinterpret_cast<const X*>(vIn);
|
||||||
|
const auto alpha = reinterpret_cast<const Y*>(vAlpha);
|
||||||
|
const auto dLdO = reinterpret_cast<const Y*>(vdLdO);
|
||||||
|
auto dLdI = reinterpret_cast<Y*>(vdLdI);
|
||||||
|
auto dLdA = reinterpret_cast<Y*>(vdLdA);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong inLen, totalThreads, *sharedMem;
|
||||||
|
__shared__ int inRank, alphaRank;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
inLen = shape::length(inShapeInfo);
|
||||||
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
inRank = shape::rank(inShapeInfo);
|
||||||
|
alphaRank = shape::rank(alphaShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
Nd4jLong* coords = sharedMem + threadIdx.x * inRank;
|
||||||
|
|
||||||
|
for (int i = tid; i < inLen; i += totalThreads) {
|
||||||
|
|
||||||
|
shape::index2coords(inRank, inShapeInfo + 1, i, inLen, coords);
|
||||||
|
|
||||||
|
const auto inOffset = shape::getOffset(0, inShapeInfo + 1, inShapeInfo + inRank + 1, coords, inRank);
|
||||||
|
const auto dLdOOffset = shape::getOffset(0, dLdOShapeInfo + 1, dLdOShapeInfo + inRank + 1, coords, inRank);
|
||||||
|
const auto dLdIOffset = shape::getOffset(0, dLdIShapeInfo + 1, dLdIShapeInfo + inRank + 1, coords, inRank);
|
||||||
|
|
||||||
|
const auto xVal = in[inOffset];
|
||||||
|
const auto grO = dLdO[dLdOOffset];
|
||||||
|
|
||||||
|
if(xVal < 0) {
|
||||||
|
|
||||||
|
for (uint j = 0; j < alphaRank; ++j)
|
||||||
|
if(alphaShapeInfo[j + 1] == 1)
|
||||||
|
coords[j + 1] = 0;
|
||||||
|
|
||||||
|
const auto alphaOffset = shape::getOffset(0, alphaShapeInfo + 1, alphaShapeInfo + alphaRank + 1, coords + 1, alphaRank);
|
||||||
|
const auto dLdAOffset = shape::getOffset(0, dLdAShapeInfo + 1, dLdAShapeInfo + alphaRank + 1, coords + 1, alphaRank);
|
||||||
|
|
||||||
|
dLdI[dLdIOffset] = grO * alpha[alphaOffset];
|
||||||
|
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd<Y>(&dLdA[dLdAOffset], static_cast<Y>(grO * xVal));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
dLdI[dLdIOffset] = grO;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Y>
|
||||||
|
__host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo) {
|
||||||
|
|
||||||
|
preluBPCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vIn, inShapeInfo, vAlpha, alphaShapeInfo, vdLdO, dLdOShapeInfo, vdLdI, dLdIShapeInfo, vdLdA, dLdAShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void preluBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
|
||||||
|
|
||||||
|
dLdA.nullify();
|
||||||
|
|
||||||
|
PointersManager manager(context, "preluBP");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||||
|
|
||||||
|
const auto xType = input.dataType();
|
||||||
|
const auto zType = alpha.dataType();
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO});
|
||||||
|
BUILD_DOUBLE_SELECTOR(xType, zType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), dLdO.getSpecialBuffer(), dLdO.getSpecialShapeInfo(), dLdI.getSpecialBuffer(), dLdI.getSpecialShapeInfo(), dLdA.getSpecialBuffer(), dLdA.getSpecialShapeInfo()), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&dLdI, &dLdA}, {&input, &alpha, &dLdO});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) {
|
__global__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xzShapeInfo, void *vz) {
|
||||||
|
@ -439,88 +550,6 @@ void softmaxDerivative(nd4j::LaunchContext * context, const NDArray& input, NDAr
|
||||||
output.tickWriteDevice();
|
output.tickWriteDevice();
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
template<typename X, typename Y>
|
|
||||||
__global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeInfo,
|
|
||||||
const void *vAlpha, const Nd4jLong *alphaShapeInfo,
|
|
||||||
const void *vdLdO, const Nd4jLong *dLdOShapeInfo,
|
|
||||||
void *vdLdI, const Nd4jLong *dLdIShapeInfo,
|
|
||||||
void *vdLdA, const Nd4jLong *dLdAShapeInfo) {
|
|
||||||
|
|
||||||
const auto in = reinterpret_cast<const X*>(vIn);
|
|
||||||
const auto alpha = reinterpret_cast<const Y*>(vAlpha);
|
|
||||||
const auto dLdO = reinterpret_cast<const Y*>(vdLdO);
|
|
||||||
auto dLdI = reinterpret_cast<Y*>(vdLdI);
|
|
||||||
auto dLdA = reinterpret_cast<Y*>(vdLdA);
|
|
||||||
|
|
||||||
__shared__ Nd4jLong alphaLen;
|
|
||||||
|
|
||||||
if (threadIdx.x == 0)
|
|
||||||
alphaLen = shape::length(alphaShapeInfo);
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
const auto i = blockIdx.x * blockDim.x + threadIdx.x;
|
|
||||||
if (i >= alphaLen) return;
|
|
||||||
|
|
||||||
Nd4jLong inputIdxs[MAX_RANK*2];
|
|
||||||
int numIdxs = shape::outerArrayOffsets(inputIdxs, i, inShapeInfo, alphaShapeInfo);
|
|
||||||
Nd4jLong dLdOIdxs[MAX_RANK*2];
|
|
||||||
shape::outerArrayOffsets(dLdOIdxs, i, dLdOShapeInfo, alphaShapeInfo);
|
|
||||||
Nd4jLong dLdIIdxs[MAX_RANK*2];
|
|
||||||
shape::outerArrayOffsets(dLdIIdxs, i, dLdIShapeInfo, alphaShapeInfo);
|
|
||||||
|
|
||||||
const auto alphaOffset = shape::getIndexOffset(i, alphaShapeInfo, alphaLen);
|
|
||||||
const auto dLdAOffset = shape::getIndexOffset(i, dLdAShapeInfo, alphaLen);
|
|
||||||
|
|
||||||
for(Nd4jLong j = 0; j < numIdxs; ++j) {
|
|
||||||
|
|
||||||
const auto inInd = inputIdxs[j];
|
|
||||||
const auto dLdOInd = dLdOIdxs[j];
|
|
||||||
const auto dLdIInd = dLdIIdxs[j];
|
|
||||||
|
|
||||||
if(in[inInd] < 0) {
|
|
||||||
dLdI[dLdIInd] = dLdO[dLdOInd] * alpha[alphaOffset];
|
|
||||||
auto prevVal = dLdA[dLdAOffset];
|
|
||||||
prevVal = prevVal + dLdO[dLdOInd] * in[inInd];
|
|
||||||
dLdA[dLdAOffset] = prevVal;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
dLdI[dLdIInd] = dLdO[dLdOInd];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template<typename X, typename Y>
|
|
||||||
__host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo) {
|
|
||||||
|
|
||||||
preluBPCuda<X, Y><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(vIn, inShapeInfo, vAlpha, alphaShapeInfo, vdLdO, dLdOShapeInfo, vdLdI, dLdIShapeInfo, vdLdA, dLdAShapeInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
void preluBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& alpha, const NDArray& dLdO, NDArray& dLdI, NDArray& dLdA) {
|
|
||||||
|
|
||||||
if(!input.isActualOnDeviceSide()) input.syncToDevice();
|
|
||||||
if(!alpha.isActualOnDeviceSide()) alpha.syncToDevice();
|
|
||||||
if(!dLdO.isActualOnDeviceSide()) dLdO.syncToDevice();
|
|
||||||
|
|
||||||
const auto xType = input.dataType();
|
|
||||||
const auto zType = dLdO.dataType();
|
|
||||||
|
|
||||||
int threadsPerBlock = MAX_NUM_THREADS;
|
|
||||||
int blocksPerGrid = (alpha.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, preluBPCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), alpha.getSpecialBuffer(), alpha.getSpecialShapeInfo(), dLdO.getSpecialBuffer(), dLdO.getSpecialShapeInfo(), dLdI.getSpecialBuffer(), dLdI.getSpecialShapeInfo(), dLdA.getSpecialBuffer(), dLdA.getSpecialShapeInfo()), LIBND4J_TYPES, FLOAT_TYPES);
|
|
||||||
|
|
||||||
input.tickReadHost();
|
|
||||||
alpha.tickReadHost();
|
|
||||||
dLdO.tickReadHost();
|
|
||||||
dLdI.tickWriteDevice();
|
|
||||||
dLdA.tickWriteDevice();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
linkage void thresholdRelu_(NDArray const& input, double threshold, NDArray& output) {
|
linkage void thresholdRelu_(NDArray const& input, double threshold, NDArray& output) {
|
||||||
|
@ -545,8 +574,8 @@ __host__ linkage void preluBPCudaLauncher(const int blocksPerGrid, const int thr
|
||||||
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, (NDArray* input, double threshold, NDArray* dLdO, NDArray* output), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void thresholdReluDerivative_, (NDArray* input, double threshold, NDArray* dLdO, NDArray* output), FLOAT_TYPES);
|
||||||
BUILD_DOUBLE_TEMPLATE(template void preluCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template void preluCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
BUILD_DOUBLE_TEMPLATE(template void preluBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_TEMPLATE(template void preluBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vIn, const Nd4jLong *inShapeInfo, const void *vAlpha, const Nd4jLong *alphaShapeInfo, const void *vdLdO, const Nd4jLong *dLdOShapeInfo, void *vdLdI, const Nd4jLong *dLdIShapeInfo, void *vdLdA, const Nd4jLong *dLdAShapeInfo), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template void softMaxForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void softMaxForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template void softMaxDerivForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void softMaxDerivForVectorCudaLauncher, (const cudaStream_t* stream, const void *vx, const Nd4jLong *xzShapeInfo, void *vz), FLOAT_TYPES);
|
||||||
|
|
||||||
|
|
|
@ -803,8 +803,12 @@ __global__ static void pooling3dCuda(const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
for (coords[4] = wstart; coords[4] < wend; coords[4] += dW)
|
for (coords[4] = wstart; coords[4] < wend; coords[4] += dW)
|
||||||
sum += x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)];
|
sum += x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)];
|
||||||
|
|
||||||
if (extraParam0 == 0) //Exclude padding
|
if (extraParam0 == 0) { //Exclude padding
|
||||||
sum /= nd4j::math::nd4j_ceil<double,T>(static_cast<double>(dend - dstart) / static_cast<double>(dD)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(hend - hstart) / static_cast<double>(dH)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(wend - wstart) / static_cast<double>(dW)); //Accounts for dilation
|
uint a = (dend - dstart) / dD + ((dend - dstart) % dD == 0 ? 0 : 1);
|
||||||
|
uint b = (hend - hstart) / dH + ((hend - hstart) % dH == 0 ? 0 : 1);
|
||||||
|
uint c = (wend - wstart) / dW + ((wend - wstart) % dW == 0 ? 0 : 1);
|
||||||
|
sum /= static_cast<T>(a * b * c); // /= nd4j::math::nd4j_ceil<double,T>(static_cast<double>(dend - dstart) / static_cast<double>(dD)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(hend - hstart) / static_cast<double>(dH)) * nd4j::math::nd4j_ceil<double,T>(static_cast<double>(wend - wstart) / static_cast<double>(dW)); //Accounts for dilation
|
||||||
|
}
|
||||||
else if (extraParam0 == 1) //Include padding
|
else if (extraParam0 == 1) //Include padding
|
||||||
sum /= kProd;
|
sum /= kProd;
|
||||||
|
|
||||||
|
|
|
@ -15,26 +15,123 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/dilation2d.h>
|
#include <ops/declarable/helpers/dilation2d.h>
|
||||||
#include <array/DataTypeUtils.h>
|
#include <array/DataTypeUtils.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
template <typename X, typename Y>
|
|
||||||
static void dilation2d_(NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left) {
|
|
||||||
|
|
||||||
};
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z>
|
||||||
|
__global__ static void dilation2dCuda(const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vy, const Nd4jLong* yShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo,
|
||||||
|
const int sH, const int sW,
|
||||||
|
const int pH, const int pW,
|
||||||
|
const int dH, const int dW) {
|
||||||
|
|
||||||
void dilation2d(nd4j::LaunchContext * context, NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left) {
|
// x [bS, iH, iW, iC]
|
||||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2d_, (input, weights, output, stride_rows, stride_cols, rate_rows, rate_cols, pad_top, pad_left), LIBND4J_TYPES, FLOAT_TYPES);
|
// y [kH, kW, iC]
|
||||||
|
// z [bS, oH, oW, iC]
|
||||||
|
|
||||||
|
const X* x = reinterpret_cast<const X*>(vx);
|
||||||
|
const X* y = reinterpret_cast<const X*>(vy);
|
||||||
|
Z* z = reinterpret_cast<Z*>(vz);
|
||||||
|
|
||||||
|
__shared__ int xzRank, yRank;
|
||||||
|
__shared__ uint iH, iW, kH, kW;
|
||||||
|
__shared__ Nd4jLong *sharedMem, zLen;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
zLen = shape::length(zShapeInfo);
|
||||||
|
|
||||||
|
xzRank = shape::rank(xShapeInfo);
|
||||||
|
yRank = shape::rank(yShapeInfo);
|
||||||
|
|
||||||
|
iH = xShapeInfo[2];
|
||||||
|
iW = xShapeInfo[3];
|
||||||
|
|
||||||
|
kH = yShapeInfo[1];
|
||||||
|
kW = yShapeInfo[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_DOUBLE_TEMPLATE(template void dilation2d_, (NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left), LIBND4J_TYPES, FLOAT_TYPES);
|
__syncthreads();
|
||||||
|
|
||||||
|
const auto zInd = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
|
||||||
|
if(zInd >= zLen)
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto xzCoords = sharedMem + threadIdx.x * (xzRank + yRank);
|
||||||
|
auto yCoords = xzCoords + xzRank;
|
||||||
|
|
||||||
|
shape::index2coords(xzRank, zShapeInfo + 1, zInd, zLen, xzCoords);
|
||||||
|
|
||||||
|
const auto zOffset = shape::getOffset(zShapeInfo, xzCoords);
|
||||||
|
|
||||||
|
yCoords[2] = xzCoords[3]; // iC coordinate is same for x, y and z
|
||||||
|
|
||||||
|
const auto oh = xzCoords[1];
|
||||||
|
const auto ow = xzCoords[2];
|
||||||
|
|
||||||
|
X max = -DataTypeUtils::max<X>();
|
||||||
|
|
||||||
|
for (yCoords[0] = 0; yCoords[0] < kH; ++yCoords[0]) {
|
||||||
|
xzCoords[1] = oh * sH - pH + yCoords[0] * dH;
|
||||||
|
if (xzCoords[1] < 0 || xzCoords[1] >= iH) continue;
|
||||||
|
|
||||||
|
for (yCoords[1] = 0; yCoords[1] < kW; ++yCoords[1]) {
|
||||||
|
xzCoords[2] = ow * sW - pW + yCoords[1] * dW;
|
||||||
|
if(xzCoords[2] < 0 || xzCoords[2] >= iW) continue;
|
||||||
|
|
||||||
|
const X val = x[shape::getOffset(xShapeInfo, xzCoords)] + y[shape::getOffset(yShapeInfo, yCoords)];
|
||||||
|
if (val > max)
|
||||||
|
max = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
z[zOffset] = static_cast<Z>(max);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename X, typename Z>
|
||||||
|
static void dilation2dCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
||||||
|
const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vy, const Nd4jLong* yShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo,
|
||||||
|
const int sH, const int sW,
|
||||||
|
const int pH, const int pW,
|
||||||
|
const int dH, const int dW) {
|
||||||
|
|
||||||
|
dilation2dCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, sH, sW, pH, pW, dH, dW);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void dilation2dCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
|
void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
|
||||||
|
|
||||||
|
PointersManager manager(context, "dilation2d");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = (weights->rankOf() + output->rankOf()) * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, weights});
|
||||||
|
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), dilation2dCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), weights->getSpecialBuffer(), weights->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), sH, sW, pH, pW, dH, dW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({output}, {input, weights});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,6 +107,102 @@ void gruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
|
||||||
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
|
void gruCellBP(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* dLdWx0,
|
||||||
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
|
const NDArray* dLdWh0, const NDArray* dLdb0, NDArray* dLdx, NDArray* dLdh0, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
|
||||||
|
|
||||||
|
// x input [bS, iS]
|
||||||
|
// h0 previous cell output [bS, nU], that is at previous time step t-1
|
||||||
|
// Wx input-to-hidden weights, [iS, 3*nU]
|
||||||
|
// Wh hidden-to-hidden weights, [nU, 3*nU]
|
||||||
|
// b biases, [3*nU]
|
||||||
|
// dLdh gradient wrt output, [bS,nU], that is epsilon_next
|
||||||
|
// dLdWx0 gradient wrt Wx at previous time step, [iS, 3*nU]
|
||||||
|
// dLdWh0 gradient wrt Wh at previous time step, [nU, 3*nU]
|
||||||
|
// dLdb0 gradient wrt b at previous time step, [3*nU]
|
||||||
|
|
||||||
|
// dLdx gradient wrt x, [bS, iS], that is epsilon
|
||||||
|
// dLdh0 gradient wrt h0, [bS, nU]
|
||||||
|
// dLdWx gradient wrt Wx, [iS, 3*nU]
|
||||||
|
// dLdWh gradient wrt Wh, [nU, 3*nU]
|
||||||
|
// dLdb gradient wrt b at previous time step, [3*nU]
|
||||||
|
|
||||||
|
// h is current cell output [bS, nU], that is at current time step t
|
||||||
|
|
||||||
|
const int nU = h0->sizeAt(1);
|
||||||
|
|
||||||
|
// ***** feed forward step ***** //
|
||||||
|
// gates = sigmoid(x*Wx + h0*Wh + b)
|
||||||
|
auto gates = sigmoid(mmul(*x, (*Wx)({0,0, 0,2*nU})) + mmul(*h0, (*Wh)({0,0, 0,2*nU})) + (*b)({0,2*nU})); // [bS, 2*nU] + [bS, 2*nU] + [1, 2*nU] = [bS, 2*nU]
|
||||||
|
// reset gate
|
||||||
|
auto r = gates({0,0, 0, nU}); // [bS, nU]
|
||||||
|
// update gate
|
||||||
|
auto u = gates({0,0, nU, 2*nU}); // [bS, nU]
|
||||||
|
// ◦ means element-wise product or so called Hadamard product
|
||||||
|
// n = tanh(x*Wx + (r◦h0)*Wh + b)
|
||||||
|
auto n = tanh(mmul(*x, (*Wx)({0,0, 2*nU,3*nU})) + mmul((*h0)*r, (*Wh)({0,0, 2*nU,3*nU})) + (*b)({2*nU,3*nU})); // [bS, nU]
|
||||||
|
|
||||||
|
// ***** back prop step ***** //
|
||||||
|
auto Wxr = (*Wx)({0,0, 0, nU});
|
||||||
|
auto Wxu = (*Wx)({0,0, nU, 2*nU});
|
||||||
|
auto Wxn = (*Wx)({0,0, 2*nU,3*nU});
|
||||||
|
auto Whr = (*Wh)({0,0, 0, nU});
|
||||||
|
auto Whu = (*Wh)({0,0, nU, 2*nU});
|
||||||
|
auto Whn = (*Wh)({0,0, 2*nU,3*nU});
|
||||||
|
auto WxrT = Wxr.transpose();
|
||||||
|
auto WxuT = Wxu.transpose();
|
||||||
|
auto WxnT = Wxn.transpose();
|
||||||
|
auto WhrT = Whr.transpose();
|
||||||
|
auto WhuT = Whu.transpose();
|
||||||
|
auto WhnT = Whn.transpose();
|
||||||
|
auto xT = x->transpose();
|
||||||
|
auto h0T = h0->transpose();
|
||||||
|
|
||||||
|
auto dLdWxr = (*dLdWx)({0,0, 0, nU});
|
||||||
|
auto dLdWxu = (*dLdWx)({0,0, nU, 2*nU});
|
||||||
|
auto dLdWxn = (*dLdWx)({0,0, 2*nU,3*nU});
|
||||||
|
|
||||||
|
auto dLdWhr = (*dLdWh)({0,0, 0, nU});
|
||||||
|
auto dLdWhu = (*dLdWh)({0,0, nU, 2*nU});
|
||||||
|
auto dLdWhn = (*dLdWh)({0,0, 2*nU,3*nU});
|
||||||
|
|
||||||
|
auto dLdbr = (*dLdb)({0, nU});
|
||||||
|
auto dLdbu = (*dLdb)({nU, 2*nU});
|
||||||
|
auto dLdbn = (*dLdb)({2*nU,3*nU});
|
||||||
|
|
||||||
|
auto dhdu = *h0 - n; // [bS, nU]
|
||||||
|
auto dhdn = 1.f - u; // [bS, nU]
|
||||||
|
auto dSigdu = u * (1.f - u); // [bS, nU]
|
||||||
|
auto dSigdr = r * (1.f - r); // [bS, nU]
|
||||||
|
auto dActdn = 1.f - n * n; // [bS, nU]
|
||||||
|
auto dndr = mmul(dActdn * (*h0), WhnT);
|
||||||
|
auto drdh0 = mmul(dSigdr, WhrT);
|
||||||
|
|
||||||
|
auto dLdn = (*dLdh) * dhdn;
|
||||||
|
auto dLdu = (*dLdh) * dhdu;
|
||||||
|
auto dLdr = dLdn * dndr;
|
||||||
|
|
||||||
|
dLdx->assign( mmul(dLdu * dSigdu, WxuT) + mmul(dLdr * dSigdr, WxrT) + mmul(dLdn * dActdn, WxnT) ); // [bS,iS]
|
||||||
|
dLdh0->assign( mmul(dLdu * dSigdu, WhuT) + mmul(dLdn * dActdn * (r + drdh0), WhnT) + (*dLdh)*u ); // [bS,nU]
|
||||||
|
|
||||||
|
dLdWxr.assign( mmul(xT, dSigdr * dLdr) ); // [iS,nU]
|
||||||
|
dLdWhr.assign( mmul(h0T, dSigdr * dLdr) ); // [nU,nU]
|
||||||
|
|
||||||
|
dLdWxu.assign( mmul(xT, dSigdu * dLdu) ); // [iS,nU]
|
||||||
|
dLdWhu.assign( mmul(h0T, dSigdu * dLdu) ); // [nU,nU]
|
||||||
|
|
||||||
|
dLdWxn.assign( mmul(xT, dActdn * dLdn) ); // [iS,nU]
|
||||||
|
dLdWhn.assign( mmul((r*(*h0)).transpose(), dActdn * dLdn) ); // [nU,nU]
|
||||||
|
|
||||||
|
dLdbr.assign( (dSigdr * dLdr).reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||||
|
dLdbu.assign( (dSigdu * dLdu).reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||||
|
dLdbn.assign( (dActdn * dLdn).reduceAlongDims(reduce::Sum, {0})); // [nU]
|
||||||
|
|
||||||
|
if(dLdWx0 != nullptr)
|
||||||
|
*dLdWx += *dLdWx0;
|
||||||
|
|
||||||
|
if(dLdWh0 != nullptr)
|
||||||
|
*dLdWh += *dLdWh0;
|
||||||
|
|
||||||
|
if(dLdb0 != nullptr)
|
||||||
|
*dLdb += *dLdb0;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,13 +23,17 @@
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
#include <Status.h>
|
#include <Status.h>
|
||||||
#include <ConstantTadHelper.h>
|
#include <ConstantTadHelper.h>
|
||||||
|
#include <ShapeUtils.h>
|
||||||
|
|
||||||
|
#include <cusolverDn.h>
|
||||||
|
#include <cuda_exception.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __device__ void _swapRows(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
static __device__ void swapRows_(T* matrix, Nd4jLong* shape, int theFirst, int theSecond, Nd4jLong N) {
|
||||||
if (theFirst != theSecond) {
|
if (theFirst != theSecond) {
|
||||||
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
auto step = blockDim.x * gridDim.x;
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
@ -46,32 +50,180 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// BUILD_SINGLE_TEMPLATE(template void _swapRows, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
|
// BUILD_SINGLE_TEMPLATE(template void swapRows_, (NDArray* matrix, int theFirst, int theSecond), FLOAT_TYPES);
|
||||||
//
|
//
|
||||||
// void swapRows(NDArray* matrix, int theFirst, int theSecond) {
|
// void swapRows(NDArray* matrix, int theFirst, int theSecond) {
|
||||||
// BUILD_SINGLE_SELECTOR(matrix->dataType(), _swapRows, (matrix, theFirst, theSecond), FLOAT_TYPES);
|
// BUILD_SINGLE_SELECTOR(matrix->dataType(), swapRows_, (matrix, theFirst, theSecond), FLOAT_TYPES);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void _invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
static __global__ void invertKernelLow(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
|
||||||
|
__shared__ T* inverted;
|
||||||
|
__shared__ T* input;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
inverted = reinterpret_cast<T*>(invertedBuf);
|
||||||
|
input = reinterpret_cast<T*>(inputBuf);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int i = start + 1; i < n; i += step) {
|
||||||
|
Nd4jLong pos[] = {i, i - 1};
|
||||||
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||||
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
||||||
|
inverted[zIndex] = -input[xIndex];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void _invertLowerMatrix, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
|
template <typename T>
|
||||||
|
static __global__ void upvertKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
|
||||||
|
__shared__ T* inverted;
|
||||||
|
__shared__ T* input;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
inverted = reinterpret_cast<T*>(invertedBuf);
|
||||||
|
input = reinterpret_cast<T*>(inputBuf);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int i = start + 1; i < n; i += step) {
|
||||||
|
Nd4jLong pos[] = {i, i};
|
||||||
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||||
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
||||||
|
inverted[zIndex] /= input[xIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void upvertKernelUp(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
|
||||||
|
__shared__ T* inverted;
|
||||||
|
__shared__ T* input;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
inverted = reinterpret_cast<T*>(invertedBuf);
|
||||||
|
input = reinterpret_cast<T*>(inputBuf);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int i = start + 1; i < n - 1; i += step) {
|
||||||
|
Nd4jLong pos[] = {i, i + 1};
|
||||||
|
Nd4jLong posY[] = {i, i};
|
||||||
|
Nd4jLong posX[] = {i + 1, i + 1};
|
||||||
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||||
|
auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||||
|
// auto yIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), pos, 2);
|
||||||
|
auto iIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posX, 2);
|
||||||
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), pos, 2);
|
||||||
|
inverted[zIndex] -= input[xIndex] * inverted[iIndex] / input[yIndex];
|
||||||
|
//inputMatrix->t<T>(i, i + 1) * invertedMatrix->t<T>(i + 1, i + 1) / inputMatrix->t<T>(i, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void invertLowKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
|
||||||
|
__shared__ T* inverted;
|
||||||
|
__shared__ T* input;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
inverted = reinterpret_cast<T*>(invertedBuf);
|
||||||
|
input = reinterpret_cast<T*>(inputBuf);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
// auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int i = blockIdx.x + 2; i < n; i += gridDim.x) {
|
||||||
|
for (int j = i - 2; j > -1; --j)
|
||||||
|
for (int k = threadIdx.x; k < i; k+= blockDim.x) {
|
||||||
|
Nd4jLong posZ[] = {i, j};
|
||||||
|
Nd4jLong posX[] = {k, j};
|
||||||
|
Nd4jLong posY[] = {i, k};
|
||||||
|
|
||||||
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
||||||
|
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
|
||||||
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
|
||||||
|
inverted[zIndex] -= inverted[yIndex] * input[xIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void invertUpKernel(void* invertedBuf, Nd4jLong* invertedShape, void* inputBuf, Nd4jLong* inputShape, Nd4jLong n) {
|
||||||
|
__shared__ T* inverted;
|
||||||
|
__shared__ T* input;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
inverted = reinterpret_cast<T*>(invertedBuf);
|
||||||
|
input = reinterpret_cast<T*>(inputBuf);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
// auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int i = n - blockIdx.x - 2; i >= 0; i -= gridDim.x) {
|
||||||
|
for (int j = i + 2; j < n; j++)
|
||||||
|
for (int k = i + threadIdx.x; k < n; k+= blockDim.x) {
|
||||||
|
Nd4jLong posZ[] = {i, j};
|
||||||
|
Nd4jLong posY[] = {k, j};
|
||||||
|
Nd4jLong posX[] = {i, k};
|
||||||
|
Nd4jLong posD[] = {i, i};
|
||||||
|
|
||||||
|
auto xIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posX, 2);
|
||||||
|
auto yIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posY, 2);
|
||||||
|
auto dIndex = shape::getOffset(0, shape::shapeOf(inputShape), shape::stride(inputShape), posD, 2);
|
||||||
|
auto zIndex = shape::getOffset(0, shape::shapeOf(invertedShape), shape::stride(invertedShape), posZ, 2);
|
||||||
|
inverted[zIndex] -= inverted[yIndex] * input[xIndex] / input[dIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void invertLowerMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||||
|
int n = inputMatrix->rows();
|
||||||
|
invertedMatrix->setIdentity();
|
||||||
|
|
||||||
|
if (inputMatrix->isIdentityMatrix()) return;
|
||||||
|
LaunchContext* context = inputMatrix->getContext();
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
|
invertKernelLow<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
|
invertLowKernel<T><<<n, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void invertLowerMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
|
||||||
|
|
||||||
void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
void invertLowerMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), _invertLowerMatrix, (inputMatrix, invertedMatrix), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertLowerMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void _invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
static void invertUpperMatrix_(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||||
|
int n = inputMatrix->rows();
|
||||||
|
invertedMatrix->setIdentity();
|
||||||
|
auto stream = inputMatrix->getContext()->getCudaStream();
|
||||||
|
if (inputMatrix->isIdentityMatrix()) { // the inverse for I is I
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
upvertKernel<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
|
upvertKernelUp<T><<<1, n, 128, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
|
invertUpKernel<T><<<n, n, 256, *stream>>>(invertedMatrix->specialBuffer(), invertedMatrix->specialShapeInfo(), inputMatrix->specialBuffer(), inputMatrix->specialShapeInfo(), n);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void _invertUpperMatrix, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void invertUpperMatrix_, (NDArray* inputMatrix, NDArray* invertedMatrix);, FLOAT_TYPES);
|
||||||
|
|
||||||
void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
void invertUpperMatrix(NDArray* inputMatrix, NDArray* invertedMatrix) {
|
||||||
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), _invertUpperMatrix, (inputMatrix, invertedMatrix), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(inputMatrix->dataType(), invertUpperMatrix_, (inputMatrix, invertedMatrix), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -91,8 +243,8 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
if( pivotValue != T(0.0) ) {
|
if( pivotValue != T(0.0) ) {
|
||||||
_swapRows<T>(compound, compoundShape, pivot, i, rowNum);
|
swapRows_<T>(compound, compoundShape, pivot, i, rowNum);
|
||||||
_swapRows<T>(permutation, permutationShape, pivot, i, rowNum);
|
swapRows_<T>(permutation, permutationShape, pivot, i, rowNum);
|
||||||
if (pivot != i)
|
if (pivot != i)
|
||||||
swapCount++;
|
swapCount++;
|
||||||
|
|
||||||
|
@ -115,124 +267,582 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template <typename T>
|
|
||||||
static __global__ void determinantKernel(T* compound, Nd4jLong* shape, T* result) {
|
|
||||||
__shared__ Nd4jLong len;
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
template <typename T, typename F>
|
||||||
len = shape::length(shape);
|
static __global__ void determinantKernel(T* compound, T* result, Nd4jLong len) {
|
||||||
|
__shared__ F tempRes;
|
||||||
|
if (blockIdx.x == 0) {
|
||||||
|
tempRes = (F)result[0];
|
||||||
}
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
auto step = blockDim.x * gridDim.x;
|
auto step = blockDim.x * gridDim.x;
|
||||||
for (auto i = start; i < len; i += step) {
|
for (auto i = start; i < len; i += step) {
|
||||||
Nd4jLong di[] = {i, i};
|
auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2);
|
||||||
auto pos = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2);
|
math::atomics::nd4j_atomicMul<F>(&tempRes, (F)compound[pos]);
|
||||||
math::atomics::nd4j_atomicMul(result, compound[pos]);
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (blockIdx.x == 0) {
|
||||||
|
result[0] = (T)tempRes;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template <typename T>
|
|
||||||
static __global__ void determinantFullKernel(T* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape, Nd4jLong* tadShape, Nd4jLong* tadOffsets) {
|
|
||||||
|
|
||||||
|
template <typename T, typename F>
|
||||||
|
static __global__ void determinantLogKernel(T* compound, T* result, Nd4jLong len) {
|
||||||
|
__shared__ F tempRes;
|
||||||
|
if (blockIdx.x == 0) {
|
||||||
|
tempRes = (F)result[0];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
for (auto i = start; i < len; i += step) {
|
||||||
|
auto pos = i * len + i; //shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), di, 2);
|
||||||
|
math::atomics::nd4j_atomicMul<F>(&tempRes, (F)compound[pos]);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (blockIdx.x == 0) {
|
||||||
|
result[0] = (T)math::nd4j_log<F,F>(math::nd4j_abs(tempRes));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, typename F>
|
||||||
|
static __global__ void fillMatrix(void* output, Nd4jLong* outShape, void* input, Nd4jLong* inputShape, Nd4jLong pos, Nd4jLong rowLen) {
|
||||||
|
__shared__ F* matrix;
|
||||||
|
__shared__ T* inputBuf;
|
||||||
|
__shared__ Nd4jLong inputLen;
|
||||||
|
__shared__ Nd4jLong n2;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
matrix = reinterpret_cast<F*>(output);
|
||||||
|
inputBuf = reinterpret_cast<T*>(input);
|
||||||
|
inputLen = shape::length(inputShape);
|
||||||
|
n2 = rowLen * rowLen;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int k = pos + start, j = start; j < n2; k += step, j += step) {
|
||||||
|
auto xIndex = shape::getIndexOffset(k, inputShape, inputLen);
|
||||||
|
matrix[j] = (F)inputBuf[xIndex];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template <typename F>
|
||||||
|
static __global__ void fillUpPermutation(void* output, Nd4jLong* shape, int* source, int rowNum) {
|
||||||
|
__shared__ F* permutation;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
permutation = reinterpret_cast<F*>(output);
|
||||||
|
}
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
for (auto i = start; i < rowNum; i += step) {
|
||||||
|
int val = source[i] - 1;
|
||||||
|
Nd4jLong posF[] = {i, val};
|
||||||
|
auto pos = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), posF, 2);
|
||||||
|
permutation[pos] = F(1.f);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static NDArray _lup(LaunchContext* context, NDArray* input, NDArray* compound, NDArray* permutation) {
|
static void lup_(LaunchContext* context, NDArray* input, NDArray* compound, NDArray* permutation) {
|
||||||
NDArray determinant = NDArrayFactory::create<T>(1.f);
|
|
||||||
auto rowNum = input->rows();
|
|
||||||
auto columnNum = input->columns();
|
|
||||||
|
|
||||||
NDArray compoundMatrix = *input; // copy
|
|
||||||
NDArray permutationMatrix(input, false, input->getContext()); // has same shape as input and contiguous strides
|
|
||||||
permutationMatrix.setIdentity();
|
|
||||||
|
|
||||||
T pivotValue; // = T(0.0);
|
|
||||||
int pivot; // = -1;
|
|
||||||
int swapCount = 0;
|
|
||||||
T* compoundBuf = reinterpret_cast<T*>(compoundMatrix.specialBuffer());
|
|
||||||
T* permutationBuf = reinterpret_cast<T*>(permutationMatrix.specialBuffer());
|
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
lupKernel<T><<<256, 256, 1024, *stream>>>(compoundBuf, compoundMatrix.specialShapeInfo(), permutationBuf, permutationMatrix.specialShapeInfo(), rowNum);
|
auto n = input->rows();
|
||||||
determinantKernel<T><<<256, 256, 1024, *stream>>>(compoundBuf, compoundMatrix.specialShapeInfo(), reinterpret_cast<T*>(determinant.specialBuffer()));
|
cusolverDnHandle_t cusolverH = nullptr;
|
||||||
// for (int e = 0; e < rowNum; e++) {
|
cusolverStatus_t status = cusolverDnCreate(&cusolverH);
|
||||||
// // nd4j_printf("Compound matrix diag %i %f.\n", e, (*compoundMatrix)(e, e));
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
// determinant *= compoundMatrix.e<T>(e, e);
|
throw cuda_exception::build("Cannot create cuSolver handle", status);
|
||||||
// }
|
}
|
||||||
if (swapCount % 2) determinant = -determinant;
|
status = cusolverDnSetStream(cusolverH, *stream);
|
||||||
if (compound != nullptr)
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
compound->assign(compoundMatrix);
|
throw cuda_exception::build("Cannot set up stream for cuda solver", status);
|
||||||
if (permutation != nullptr)
|
}
|
||||||
permutation->assign(permutationMatrix);
|
int lwork = 0;
|
||||||
return determinant;
|
int *d_info = nullptr;
|
||||||
|
|
||||||
|
auto err = cudaMalloc((void **) &d_info, sizeof(int));
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver info buffer", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
DataType dtype = input->dataType();
|
||||||
|
switch(dtype) {
|
||||||
|
|
||||||
|
case DataType::DOUBLE: {
|
||||||
|
double *d_work = nullptr;
|
||||||
|
err = cudaMalloc((void **) &d_work, sizeof(float) * lwork);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", err);
|
||||||
|
}
|
||||||
|
double *matrix = reinterpret_cast<double*>(input->specialBuffer());
|
||||||
|
status = cusolverDnDgetrf_bufferSize(
|
||||||
|
cusolverH,
|
||||||
|
n,
|
||||||
|
n,
|
||||||
|
matrix,
|
||||||
|
n,
|
||||||
|
&lwork);
|
||||||
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
|
throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status);
|
||||||
|
}
|
||||||
|
if (permutation == nullptr)
|
||||||
|
status = cusolverDnDgetrf(
|
||||||
|
cusolverH,
|
||||||
|
n,
|
||||||
|
n,
|
||||||
|
matrix,
|
||||||
|
n,
|
||||||
|
d_work,
|
||||||
|
nullptr,
|
||||||
|
d_info);
|
||||||
|
else {
|
||||||
|
NDArray permutVector('c', {n}, nd4j::DataType::INT32, context);
|
||||||
|
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
|
||||||
|
status = cusolverDnDgetrf(
|
||||||
|
cusolverH,
|
||||||
|
n,
|
||||||
|
n,
|
||||||
|
matrix,
|
||||||
|
n,
|
||||||
|
d_work,
|
||||||
|
permutationBuf,
|
||||||
|
d_info);
|
||||||
|
fillUpPermutation<double><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
|
||||||
|
permutation->tickWriteDevice();
|
||||||
|
}
|
||||||
|
err = cudaFree(d_work);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case DataType::FLOAT32: {
|
||||||
|
float *matrix = reinterpret_cast<float*>(input->specialBuffer());
|
||||||
|
float *d_work = nullptr;
|
||||||
|
err = cudaMalloc((void **) &d_work, sizeof(float) * lwork);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::lup_: Cannot allocate memory for solver data buffer", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
status = cusolverDnSgetrf_bufferSize(
|
||||||
|
cusolverH,
|
||||||
|
n,
|
||||||
|
n,
|
||||||
|
matrix,
|
||||||
|
n,
|
||||||
|
&lwork);
|
||||||
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
|
throw cuda_exception::build("helpers::lup_: Cannot create cuSolver handle", status);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (permutation == nullptr)
|
||||||
|
status = cusolverDnSgetrf(
|
||||||
|
cusolverH,
|
||||||
|
n,
|
||||||
|
n,
|
||||||
|
matrix,
|
||||||
|
n,
|
||||||
|
d_work,
|
||||||
|
nullptr,
|
||||||
|
d_info);
|
||||||
|
else {
|
||||||
|
NDArray permutVector('c', {n}, nd4j::DataType::INT32, context);
|
||||||
|
int *permutationBuf = reinterpret_cast<int *>(permutVector.specialBuffer());
|
||||||
|
status = cusolverDnSgetrf(
|
||||||
|
cusolverH,
|
||||||
|
n,
|
||||||
|
n,
|
||||||
|
matrix,
|
||||||
|
n,
|
||||||
|
d_work,
|
||||||
|
permutationBuf,
|
||||||
|
d_info);
|
||||||
|
fillUpPermutation<float><<<n, n, 128, *stream>>>(permutation->specialBuffer(), permutation->specialShapeInfo(), permutationBuf, n);
|
||||||
|
permutation->tickWriteDevice();
|
||||||
|
}
|
||||||
|
err = cudaFree(d_work);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver data buffer", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
|
throw cuda_exception::build("helpers::lup_: Cannot make LU decomposition", status);
|
||||||
|
}
|
||||||
|
err = cudaFree(d_info);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err);
|
||||||
|
}
|
||||||
|
cusolverDnDestroy(cusolverH);
|
||||||
|
// NDArray::registerSpecialUse({input}, {input});
|
||||||
|
input->tickWriteDevice();
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template NDArray _lup, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void lup_, (LaunchContext* context, NDArray* input, NDArray* output, NDArray* permutation), FLOAT_TYPES);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int _determinant(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
|
static int determinant_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
|
||||||
Nd4jLong n = input->sizeAt(-1);
|
Nd4jLong n = input->sizeAt(-1);
|
||||||
Nd4jLong n2 = n * n;
|
Nd4jLong n2 = n * n;
|
||||||
std::vector<int> dims();
|
std::vector<int> dims();
|
||||||
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
||||||
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
||||||
|
DataType dtype = input->dataType();
|
||||||
|
if (dtype != DataType::DOUBLE)
|
||||||
|
dtype = DataType::FLOAT32;
|
||||||
|
|
||||||
//auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, input->dataType(), input->getContext()); //, block.getWorkspace());
|
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, input->getContext()); //, block.getWorkspace());
|
||||||
|
auto det = NDArrayFactory::create<T>(1);
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
auto inputBuf = reinterpret_cast<T*>(input->specialBuffer());
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer());
|
|
||||||
dim3 launchDims(256, 256, 1024);
|
dim3 launchDims(256, 256, 1024);
|
||||||
determinantFullKernel<T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(inputBuf, input->specialShapeInfo(), outputBuf, output->specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
|
output->assign(1.f);
|
||||||
// for (int e = 0; e < output->lengthOf(); e++) {
|
for (int e = 0; e < output->lengthOf(); e++) {
|
||||||
// for (int k = e * n2, row = 0; k < (e + 1) * n2; ++k, ++row)
|
Nd4jLong pos = e * n2;
|
||||||
// matrix.p(row, input->e<T>(k));
|
if (matrix.dataType() == input->dataType())
|
||||||
//// output->p(e, lup_<T>(&matrix, (NDArray*)nullptr, (NDArray*)nullptr));
|
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
// }
|
else
|
||||||
|
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
|
|
||||||
|
if (matrix.dataType() == input->dataType())
|
||||||
|
lup_<T>(context, &matrix, nullptr, nullptr);
|
||||||
|
else
|
||||||
|
lup_<float>(context, &matrix, nullptr, nullptr);
|
||||||
|
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
|
||||||
|
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
|
||||||
|
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
|
||||||
|
if (matrix.dataType() == input->dataType())
|
||||||
|
determinantKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||||
|
else
|
||||||
|
determinantKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template int _determinant, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int determinant_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||||
|
|
||||||
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return _determinant, (context, input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return determinant_, (context, input, output), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int log_abs_determinant_(NDArray* input, NDArray* output) {
|
int logAbsDeterminant_(LaunchContext* context, NDArray* input, NDArray* output) {
|
||||||
|
|
||||||
|
Nd4jLong n = input->sizeAt(-1);
|
||||||
|
Nd4jLong n2 = n * n;
|
||||||
|
std::vector<int> dims();
|
||||||
|
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
||||||
|
//auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {output->rankOf() - 1});
|
||||||
|
DataType dtype = input->dataType();
|
||||||
|
if (dtype != DataType::DOUBLE)
|
||||||
|
dtype = DataType::FLOAT32;
|
||||||
|
|
||||||
|
auto matrix = NDArrayFactory::create(input->ordering(), {n, n}, dtype, input->getContext()); //, block.getWorkspace());
|
||||||
|
auto det = NDArrayFactory::create<T>(1);
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
dim3 launchDims(256, 256, 1024);
|
||||||
|
output->assign(1.f);
|
||||||
|
for (int e = 0; e < output->lengthOf(); e++) {
|
||||||
|
Nd4jLong pos = e * n2;
|
||||||
|
if (matrix.dataType() == input->dataType())
|
||||||
|
fillMatrix<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
|
else
|
||||||
|
fillMatrix<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), pos, n);
|
||||||
|
|
||||||
|
if (matrix.dataType() == input->dataType())
|
||||||
|
lup_<T>(context, &matrix, nullptr, nullptr);
|
||||||
|
else
|
||||||
|
lup_<float>(context, &matrix, nullptr, nullptr);
|
||||||
|
auto offset = shape::getIndexOffset(e, output->shapeInfo(), output->lengthOf());
|
||||||
|
auto inputBuf = reinterpret_cast<T*>(matrix.specialBuffer());
|
||||||
|
auto outputBuf = reinterpret_cast<T*>(output->specialBuffer()) + offset;
|
||||||
|
if (matrix.dataType() == input->dataType())
|
||||||
|
determinantLogKernel<T, T><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||||
|
else
|
||||||
|
determinantLogKernel<T, float><<<launchDims.x, launchDims.y, launchDims.z, *stream >>> (inputBuf, outputBuf, n);
|
||||||
|
}
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template int log_abs_determinant_, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template int logAbsDeterminant_, (LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||||
|
|
||||||
int log_abs_determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return log_abs_determinant_, (input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return logAbsDeterminant_, (context, input, output), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int _inverse(NDArray* input, NDArray* output) {
|
static __global__ void fillLowerUpperKernel(void* lowerBuf, Nd4jLong* lowerShape, void* upperBuf, Nd4jLong* upperShape, void* matrixBuf, Nd4jLong* matrixShape, Nd4jLong n) {
|
||||||
|
|
||||||
|
__shared__ Nd4jLong* xShapeOf;
|
||||||
|
__shared__ Nd4jLong* yShapeOf;
|
||||||
|
__shared__ Nd4jLong* zShapeOf;
|
||||||
|
__shared__ Nd4jLong* xStrideOf;
|
||||||
|
__shared__ Nd4jLong* yStrideOf;
|
||||||
|
__shared__ Nd4jLong* zStrideOf;
|
||||||
|
__shared__ T* lowerMatrix;
|
||||||
|
__shared__ T* upperMatrix;
|
||||||
|
__shared__ T* matrix;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
xShapeOf = shape::shapeOf(lowerShape);
|
||||||
|
yShapeOf = shape::shapeOf(upperShape);
|
||||||
|
zShapeOf = shape::shapeOf(matrixShape);
|
||||||
|
xStrideOf = shape::stride(lowerShape);
|
||||||
|
yStrideOf = shape::stride(upperShape);
|
||||||
|
zStrideOf = shape::stride(matrixShape);
|
||||||
|
lowerMatrix = reinterpret_cast<T*>(lowerBuf);
|
||||||
|
upperMatrix = reinterpret_cast<T*>(upperBuf);
|
||||||
|
matrix = reinterpret_cast<T*>(matrixBuf);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int k = blockIdx.x; k < n; k += gridDim.x) { // and then put all values under main diagonal on to it
|
||||||
|
for (int j = threadIdx.x; j < n; j += blockDim.x) {
|
||||||
|
Nd4jLong posX[] = {j, k};
|
||||||
|
|
||||||
|
auto xPos = shape::getOffset(0, xShapeOf, xStrideOf, posX, 2);
|
||||||
|
auto yPos = shape::getOffset(0, yShapeOf, yStrideOf, posX, 2);
|
||||||
|
auto pos = shape::getOffset(0, zShapeOf, zStrideOf, posX, 2);
|
||||||
|
if (k <= j)
|
||||||
|
lowerMatrix[xPos] = matrix[pos];//(k, j);
|
||||||
|
else
|
||||||
|
upperMatrix[yPos] = matrix[pos]; //k, j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static int inverse_(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
|
||||||
|
auto n = input->sizeAt(-1);
|
||||||
|
auto n2 = n * n;
|
||||||
|
auto dtype = input->dataType();
|
||||||
|
if (dtype != DataType::DOUBLE)
|
||||||
|
dtype = DataType::FLOAT32;
|
||||||
|
NDArray matrix = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
|
||||||
|
NDArray upper = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
|
||||||
|
NDArray lower = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
|
||||||
|
NDArray compound = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
|
||||||
|
NDArray permutation = NDArrayFactory::create('c', {n, n}, dtype, input->getContext());
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
|
||||||
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {output->rankOf() - 2, output->rankOf() - 1});
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
|
for (auto i = 0LL; i < packX.numberOfTads(); i++) {
|
||||||
|
fillMatrix<T, float><<<1, n2, 128, *stream>>>(matrix.specialBuffer(), matrix.specialShapeInfo(), input->specialBuffer(), input->specialShapeInfo(), i * n2, n);
|
||||||
|
permutation.assign(0.f);
|
||||||
|
lup_<float>(context, &matrix, &compound, &permutation);
|
||||||
|
matrix.tickWriteDevice();
|
||||||
|
permutation.tickWriteDevice();
|
||||||
|
permutation.printIndexedBuffer("PERMUTE");
|
||||||
|
lower.setIdentity(); // set up U to identity matrix
|
||||||
|
upper.setIdentity();
|
||||||
|
fillLowerUpperKernel<float><<<1, n2, 128>>>(lower.specialBuffer(), lower.specialShapeInfo(), upper.specialBuffer(), upper.specialShapeInfo(), matrix.specialBuffer(), matrix.specialShapeInfo(), n);
|
||||||
|
lower.tickWriteDevice();
|
||||||
|
upper.tickWriteDevice();
|
||||||
|
invertUpperMatrix(&upper, &matrix);
|
||||||
|
invertLowerMatrix(&lower, &upper);
|
||||||
|
lower.tickWriteDevice();
|
||||||
|
upper.tickWriteDevice();
|
||||||
|
lower.printIndexedBuffer("LOWER");
|
||||||
|
upper.printIndexedBuffer("UPPER");
|
||||||
|
|
||||||
|
nd4j::MmulHelper::mmul(&matrix, &upper, &compound, 1.0, 0.0);
|
||||||
|
nd4j::MmulHelper::mmul(&compound, &permutation, &matrix, 1.0, 0.0);
|
||||||
|
// for (int k = e * n2, row = 0; k < (e + 1) * n2; k++) {
|
||||||
|
// output->t<T>(k) = matrix.template t<T>(row++);
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return _inverse, (input, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(input->dataType(), return inverse_, (context, input, output), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
|
bool checkCholeskyInput(nd4j::LaunchContext * context, NDArray const* input) {
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename F>
|
||||||
int cholesky_(NDArray* input, NDArray* output, bool inplace) {
|
__global__ void fillBatchKernel(F** dArrayBatch, F* buf, Nd4jLong* offsets, Nd4jLong batchSize) {
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (auto i = start; i < batchSize; i += step) {
|
||||||
|
dArrayBatch[i] = buf + offsets[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
__global__ void adjustResultsKernel(F* dArray, Nd4jLong* shape, Nd4jLong* offsets, Nd4jLong batchSize, Nd4jLong n) {
|
||||||
|
//auto i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
__shared__ Nd4jLong* shapeOf;
|
||||||
|
__shared__ Nd4jLong* strideOf;
|
||||||
|
if (blockIdx.x == 0 && threadIdx.x == 0) {
|
||||||
|
shapeOf = shape::shapeOf(shape);
|
||||||
|
strideOf = shape::stride(shape);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto i = blockIdx.x; i < batchSize; i+= gridDim.x) {
|
||||||
|
auto current = dArray + offsets[i];
|
||||||
|
for (auto r = threadIdx.x; r < n; r += blockDim.x) {
|
||||||
|
for (auto c = r + 1; c < n; c++) {
|
||||||
|
Nd4jLong posRC[] = {r, c};
|
||||||
|
auto pos = r * n + c; //shape::getOffset(0, shapeOf, strideOf, posRC, 2);
|
||||||
|
current[pos] = 0.;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
int cholesky__(LaunchContext* context, NDArray* input, NDArray* output, bool inplace) {
|
||||||
|
if (!inplace)
|
||||||
|
output->assign(input);
|
||||||
|
std::unique_ptr<NDArray> tempOutput(output->dup());
|
||||||
|
cusolverDnHandle_t handle = nullptr;
|
||||||
|
auto n = input->sizeAt(-1);
|
||||||
|
auto n2 = n * n;
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
auto status = cusolverDnCreate(&handle);
|
||||||
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
|
throw cuda_exception::build("helpers::cholesky_: Cannot create solver handle", status);
|
||||||
|
}
|
||||||
|
F** dArrayBatch = nullptr;
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {tempOutput->rankOf() - 2, tempOutput->rankOf() - 1});
|
||||||
|
const Nd4jLong batchSize = packX.numberOfTads();
|
||||||
|
int* dInfoArray = nullptr;
|
||||||
|
auto err = cudaMalloc((void**)&dArrayBatch, sizeof(F*) * batchSize);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver batch data buffer", err);
|
||||||
|
}
|
||||||
|
err = cudaMalloc ((void**)&dInfoArray, sizeof(int) * batchSize);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err);
|
||||||
|
}
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
fillBatchKernel<F><<<1, batchSize, 128, *stream>>>(dArrayBatch, reinterpret_cast<F*>(tempOutput->specialBuffer()), packX.specialOffsets(), batchSize);
|
||||||
|
|
||||||
|
status = cusolverDnSetStream(handle, *stream);
|
||||||
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
|
throw cuda_exception::build("helpers::cholesky_: Cannot set stream to solver handle", status);
|
||||||
|
}
|
||||||
|
const cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
|
||||||
|
if (input->dataType() == DataType::DOUBLE)
|
||||||
|
status = cusolverDnDpotrfBatched(
|
||||||
|
handle,
|
||||||
|
uplo,
|
||||||
|
n,
|
||||||
|
(double**)dArrayBatch,
|
||||||
|
n,
|
||||||
|
dInfoArray,
|
||||||
|
batchSize);
|
||||||
|
else
|
||||||
|
status = cusolverDnSpotrfBatched(
|
||||||
|
handle,
|
||||||
|
uplo,
|
||||||
|
n,
|
||||||
|
(float**)dArrayBatch,
|
||||||
|
n,
|
||||||
|
dInfoArray,
|
||||||
|
batchSize);
|
||||||
|
|
||||||
|
if (CUSOLVER_STATUS_SUCCESS != status) {
|
||||||
|
throw cuda_exception::build("helpers::cholesky_: Cholesky factorization failed for batch", status);
|
||||||
|
}
|
||||||
|
adjustResultsKernel<F><<<batchSize, n2, 128, *stream>>>(reinterpret_cast<F*>(tempOutput->specialBuffer()), packX.specialShapeInfo(), packX.specialOffsets(), batchSize, n);
|
||||||
|
|
||||||
|
err = cudaFree(dArrayBatch);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::cholesky_: Cannot deallocate memory for solver batch data buffer", err);
|
||||||
|
}
|
||||||
|
err = cudaFree(dInfoArray);
|
||||||
|
if (err) {
|
||||||
|
throw cuda_exception::build("helpers::cholesky_: Cannot allocate memory for solver errors buffer", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!inplace)
|
||||||
|
output->assign(tempOutput.get());
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int cholesky(nd4j::LaunchContext * context, NDArray* input, NDArray* output, bool inplace) {
|
// template <typename T>
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (input, output, inplace), FLOAT_TYPES);
|
int cholesky_(LaunchContext* context, NDArray* input, NDArray* output, bool inplace) {
|
||||||
}
|
if (input->dataType() == DataType::DOUBLE)
|
||||||
BUILD_SINGLE_TEMPLATE(template int cholesky_, (NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
cholesky__<double>(context, input, output, inplace);
|
||||||
BUILD_SINGLE_TEMPLATE(template int _inverse, (NDArray* input, NDArray* output), FLOAT_TYPES);
|
else if (input->dataType() == DataType::FLOAT32)
|
||||||
|
cholesky__<float>(context, input, output, inplace);
|
||||||
|
else {
|
||||||
|
std::unique_ptr<NDArray> tempOutput(NDArrayFactory::create_('c', input->getShapeAsVector(), DataType::FLOAT32, input->getContext()));
|
||||||
|
tempOutput->assign(input);
|
||||||
|
cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
|
||||||
|
output->assign(tempOutput.get());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
int cholesky(nd4j::LaunchContext* context, NDArray* input, NDArray* output, bool inplace) {
|
||||||
|
// BUILD_SINGLE_SELECTOR(input->dataType(), return cholesky_, (context, input, output, inplace), FLOAT_TYPES);
|
||||||
|
return cholesky_(context, input, output, inplace);
|
||||||
|
}
|
||||||
|
// BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
|
||||||
|
BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_TYPES);
|
||||||
|
|
||||||
int logdetFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* output) {
|
__global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) {
|
||||||
return 119;
|
__shared__ double* output;
|
||||||
|
__shared__ double* input;
|
||||||
|
__shared__ int n2;
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
output = reinterpret_cast<double*>(outputBuf);
|
||||||
|
input = reinterpret_cast<double*>(inputBuf);
|
||||||
|
n2 = shape::sizeAt(inputShape, -1) * shape::sizeAt(inputShape, -1);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (Nd4jLong i = blockIdx.x; i < batchNum; i += gridDim.x) {
|
||||||
|
double* current = input + tadOffsets[i];
|
||||||
|
Nd4jLong* shapeOf = shape::shapeOf(tadShape);
|
||||||
|
Nd4jLong* strideOf = shape::stride(tadShape);
|
||||||
|
auto zIndex = shape::getIndexOffset(i, outputShape, batchNum);
|
||||||
|
for (Nd4jLong e = threadIdx.x; e < n2; e += blockDim.x) {
|
||||||
|
Nd4jLong diag[] = {e, e};
|
||||||
|
auto xIndex = shape::getOffset(0, shapeOf, strideOf, diag, 2);
|
||||||
|
math::atomics::nd4j_atomicAdd(&output[zIndex], math::nd4j_log<double,double>(current[xIndex] * current[xIndex]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int logdetFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
auto tempOutput = input->dup('c');
|
||||||
|
auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
cholesky(context, tempOutput, tempOutput, true);
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {tempOutput->rankOf() - 2, tempOutput->rankOf() - 1});
|
||||||
|
//for (Nd4jLong e = 0; e < output->lengthOf(); e++) {
|
||||||
|
auto outputBuf = reinterpret_cast<double*>(output->specialBuffer()); // + e * n2;
|
||||||
|
logDetKernel<<<packX.numberOfTads(), n2, 128, *stream>>>(tempOutput->specialBuffer(), tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo());
|
||||||
|
//}
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
|
delete tempOutput;
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -732,10 +732,81 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Z>
|
||||||
|
__global__ void scatterForLossCuda(const void *vx, const Nd4jLong *xShapeInfo,
|
||||||
|
void *vy, const Nd4jLong *yShapeInfo,
|
||||||
|
void *vz, const Nd4jLong *zShapeInfo) {
|
||||||
|
|
||||||
|
const auto x = reinterpret_cast<const X*>(vx);
|
||||||
|
auto y = reinterpret_cast<Z*>(vy);
|
||||||
|
auto z = reinterpret_cast<Z*>(vz);
|
||||||
|
|
||||||
void scatterForLoss(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& updates, NDArray& output, const bool calcGrad) {
|
__shared__ Nd4jLong xLen, *sharedMem;
|
||||||
|
__shared__ int xRank; // xRank = zRank, yRank = xRank + 1
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
xLen = shape::length(xShapeInfo);
|
||||||
|
xRank = shape::rank(xShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const auto xInd = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
|
||||||
|
if(xInd >= xLen)
|
||||||
|
return;
|
||||||
|
|
||||||
|
auto coords = sharedMem + threadIdx.x * (xRank + 1);
|
||||||
|
|
||||||
|
shape::index2coords(xRank, xShapeInfo + 1, xInd, xLen, coords);
|
||||||
|
|
||||||
|
// y last coordinate
|
||||||
|
coords[xRank] = x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords, xRank)];
|
||||||
|
|
||||||
|
const auto yOffset = shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank + 2, coords, xRank + 1);
|
||||||
|
|
||||||
|
if(z == nullptr) { // gradient calculation
|
||||||
|
y[yOffset] -= 1.f;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords, xRank)] = y[yOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Z>
|
||||||
|
static void scatterForLossCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong* xShapeInfo, void *vy, const Nd4jLong* yShapeInfo, void *vz, const Nd4jLong* zShapeInfo) {
|
||||||
|
|
||||||
|
scatterForLossCuda<X, Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad) {
|
||||||
|
// shapes of indices and output must be the same
|
||||||
|
// shape of indices should be the same as updates shape with last dimension excluded, for example if updates is {a,b,c} then indices should be {a,b}
|
||||||
|
|
||||||
|
PointersManager manager(context, "scatterForLoss");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (indices.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = updates.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||||
|
|
||||||
|
if(calcGrad) {
|
||||||
|
NDArray::prepareSpecialUse({&updates}, {&indices});
|
||||||
|
BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.specialBuffer(), updates.specialShapeInfo(), nullptr, nullptr), INTEGER_TYPES, FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&updates}, {&indices});
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&indices, &updates});
|
||||||
|
BUILD_DOUBLE_SELECTOR(indices.dataType(), updates.dataType(), scatterForLossCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo()), INTEGER_TYPES, FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&output}, {&indices, &updates});
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,8 @@
|
||||||
|
|
||||||
#include<ops/declarable/helpers/sru.h>
|
#include<ops/declarable/helpers/sru.h>
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <MmulHelper.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -103,30 +105,432 @@ void sruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray*
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
|
__global__ static void sruBICuda(const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vwi, const Nd4jLong* wiShapeInfo,
|
||||||
|
const void* vb, const Nd4jLong* bShapeInfo,
|
||||||
|
const void* vc0, const Nd4jLong* c0ShapeInfo,
|
||||||
|
const void* vmask, const Nd4jLong* maskShapeInfo,
|
||||||
|
void* vht, const Nd4jLong* htShapeInfo,
|
||||||
|
void* vct, const Nd4jLong* ctShapeInfo) {
|
||||||
|
// inputs:
|
||||||
|
// x [time, bS, 2*K]
|
||||||
|
// wi [time, bS, 6*K], wi = mmul(x, weights);
|
||||||
|
// b [4*K]
|
||||||
|
// c0 [bS, 2*K]
|
||||||
|
// mask [bS, 2*K], optional
|
||||||
|
|
||||||
|
// outputs
|
||||||
|
// ht [time, bS, 2*K]
|
||||||
|
// ct [time, bS, 2*K]
|
||||||
|
|
||||||
|
const auto x = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto wi = reinterpret_cast<const T*>(vwi);
|
||||||
|
const auto b = reinterpret_cast<const T*>(vb);
|
||||||
|
const auto c0 = reinterpret_cast<const T*>(vc0);
|
||||||
|
const auto mask = reinterpret_cast<const T*>(vmask);
|
||||||
|
auto ht = reinterpret_cast<T*>(vht);
|
||||||
|
auto ct = reinterpret_cast<T*>(vct);
|
||||||
|
|
||||||
|
const int rank = 3;
|
||||||
|
|
||||||
|
__shared__ int time, K;
|
||||||
|
__shared__ Nd4jLong len, totalThreads, *sharedMem;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
time = xShapeInfo[1];
|
||||||
|
K = xShapeInfo[3] / 2;
|
||||||
|
len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS
|
||||||
|
|
||||||
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
__syncthreads();
|
||||||
template <typename T>
|
|
||||||
static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask,
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
|
Nd4jLong* coords = sharedMem + threadIdx.x * rank;
|
||||||
|
|
||||||
|
if(tid >= len)
|
||||||
|
return;
|
||||||
|
|
||||||
|
shape::index2coords(rank, xShapeInfo + 2, tid, len, coords + 1); // loop through last two dimensions of x : {bS, 2*K}
|
||||||
|
|
||||||
|
const auto maskOffst = mask ? shape::getOffset(0, maskShapeInfo + 1, maskShapeInfo + rank, coords + 1, rank - 1) : 0;
|
||||||
|
const auto c0Offset = shape::getOffset(0, c0ShapeInfo + 1, c0ShapeInfo + rank, coords + 1, rank - 1);
|
||||||
|
const auto bFOffset = shape::getOffset(0, bShapeInfo + 1, bShapeInfo + rank - 1, coords + 2, rank - 2);
|
||||||
|
const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride
|
||||||
|
|
||||||
|
const T maskVal = mask ? mask[maskOffst] : static_cast<T>(1);
|
||||||
|
const T bF = b[bFOffset];
|
||||||
|
const T bR = b[bROffset];
|
||||||
|
T c0Val = c0[c0Offset];
|
||||||
|
|
||||||
|
const bool flip = coords[2] >= K;
|
||||||
|
|
||||||
|
if(flip)
|
||||||
|
coords[0] = time - 1;
|
||||||
|
else
|
||||||
|
coords[0] = 0;
|
||||||
|
|
||||||
|
auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
|
||||||
|
auto htOffset = shape::getOffset(0, htShapeInfo + 1, htShapeInfo + rank + 1, coords, rank);
|
||||||
|
auto ctOffset = shape::getOffset(0, ctShapeInfo + 1, ctShapeInfo + rank + 1, coords, rank);
|
||||||
|
|
||||||
|
coords[2] *= 3;
|
||||||
|
auto wiOffset0 = shape::getOffset(0, wiShapeInfo + 1, wiShapeInfo + rank + 1, coords, rank);
|
||||||
|
auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride
|
||||||
|
auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride
|
||||||
|
|
||||||
|
// time loop
|
||||||
|
for (uint t = 0; t < time; ++t) {
|
||||||
|
|
||||||
|
// evaluate sigmoids
|
||||||
|
T ft = (1.f)/(1.f + nd4j::math::nd4j_exp<T, T>(-(wi[wiOffset1] + bF)));
|
||||||
|
T rt = (1.f)/(1.f + nd4j::math::nd4j_exp<T, T>(-(wi[wiOffset2] + bR)));
|
||||||
|
|
||||||
|
c0Val = (c0Val - wi[wiOffset0]) * ft + wi[wiOffset0];
|
||||||
|
ct[ctOffset] = c0Val;
|
||||||
|
T val = nd4j::math::nd4j_tanh<T, T>(c0Val);
|
||||||
|
T xVal = x[xOffset];
|
||||||
|
ht[htOffset] = (val * maskVal - xVal) * rt + xVal;
|
||||||
|
|
||||||
|
if(flip) {
|
||||||
|
xOffset -= xShapeInfo[rank + 1]; // first stride, corresponds to time step
|
||||||
|
htOffset -= htShapeInfo[rank + 1];
|
||||||
|
ctOffset -= htShapeInfo[rank + 1];
|
||||||
|
wiOffset0 -= wiShapeInfo[rank + 1];
|
||||||
|
wiOffset1 -= wiShapeInfo[rank + 1];
|
||||||
|
wiOffset2 -= wiShapeInfo[rank + 1];
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
xOffset += xShapeInfo[rank + 1]; // first stride, corresponds to time step
|
||||||
|
htOffset += htShapeInfo[rank + 1];
|
||||||
|
ctOffset += htShapeInfo[rank + 1];
|
||||||
|
wiOffset0 += wiShapeInfo[rank + 1];
|
||||||
|
wiOffset1 += wiShapeInfo[rank + 1];
|
||||||
|
wiOffset2 += wiShapeInfo[rank + 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void sruBICudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
||||||
|
const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vwi, const Nd4jLong* wiShapeInfo,
|
||||||
|
const void* vb, const Nd4jLong* bShapeInfo,
|
||||||
|
const void* vc0, const Nd4jLong* c0ShapeInfo,
|
||||||
|
const void* vmask, const Nd4jLong* maskShapeInfo,
|
||||||
|
void* vht, const Nd4jLong* htShapeInfo,
|
||||||
|
void* vct, const Nd4jLong* ctShapeInfo) {
|
||||||
|
|
||||||
|
sruBICuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vht, htShapeInfo, vct, ctShapeInfo);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void sruBICudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, const Nd4jLong* maskShapeInfo, void* vht, const Nd4jLong* htShapeInfo, void* vct, const Nd4jLong* ctShapeInfo), FLOAT_TYPES);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
|
||||||
|
|
||||||
|
// x = x * mask
|
||||||
|
if(mask)
|
||||||
|
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
|
||||||
|
|
||||||
|
// U = x * w
|
||||||
|
NDArray wi = mmul(*x, *w); // U [time x bS x 6*K]
|
||||||
|
|
||||||
|
PointersManager manager(context, "sru_bi");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / threadsPerBlock; // loop through last two dimensions of x array -> bS, 2*K
|
||||||
|
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * x->rankOf() + 128;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ht, ct}, {x, &wi, b, c0, mask});
|
||||||
|
BUILD_SINGLE_SELECTOR(x->dataType(), sruBICudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), wi.getSpecialBuffer(), wi.getSpecialShapeInfo(), b->getSpecialBuffer(), b->getSpecialShapeInfo(), c0->getSpecialBuffer(), c0->getSpecialShapeInfo(), mask ? mask->getSpecialBuffer() : nullptr, mask ? mask->getSpecialShapeInfo() : nullptr, ht->specialBuffer(), ht->specialShapeInfo(), ct->specialBuffer(), ct->specialShapeInfo()), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({ht, ct}, {x, &wi, b, c0, mask});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
__global__ static void sruBIBPCuda(const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vwi, const Nd4jLong* wiShapeInfo,
|
||||||
|
const void* vb, const Nd4jLong* bShapeInfo,
|
||||||
|
const void* vc0, const Nd4jLong* c0ShapeInfo,
|
||||||
|
const void* vmask, const Nd4jLong* maskShapeInfo,
|
||||||
|
const void* vct, const Nd4jLong* ctShapeInfo,
|
||||||
|
const void* vgradHt, const Nd4jLong* gradHtShapeInfo,
|
||||||
|
const void* vgradCt, const Nd4jLong* gradCtShapeInfo,
|
||||||
|
void* vgradI, const Nd4jLong* gradIShapeInfo,
|
||||||
|
void* vgradWi, const Nd4jLong* gradWiShapeInfo,
|
||||||
|
void* vgradB, const Nd4jLong* gradBShapeInfo,
|
||||||
|
void* vgradC0, const Nd4jLong* gradC0ShapeInfo) {
|
||||||
|
// inputs:
|
||||||
|
// x [time, bS, 2*K]
|
||||||
|
// wi [time, bS, 6*K], wi = mmul(x, weights);
|
||||||
|
// b [4*K]
|
||||||
|
// c0 [bS, 2*K]
|
||||||
|
// mask [bS, 2*K], optional
|
||||||
|
// ct [time, bS, 2*K]
|
||||||
|
// gradHt [time, bS, 2*K]
|
||||||
|
// gradCt [bS, 2*K]
|
||||||
|
|
||||||
|
// outputs
|
||||||
|
// gradI [time, bS, 2*K]
|
||||||
|
// gradWi [time, 2*K, 6*K]
|
||||||
|
// gradB [bS, 4*K]
|
||||||
|
// gradC0 [bS, 2*K]
|
||||||
|
|
||||||
|
const auto x = reinterpret_cast<const T*>(vx);
|
||||||
|
const auto wi = reinterpret_cast<const T*>(vwi);
|
||||||
|
const auto b = reinterpret_cast<const T*>(vb);
|
||||||
|
const auto c0 = reinterpret_cast<const T*>(vc0);
|
||||||
|
const auto mask = reinterpret_cast<const T*>(vmask);
|
||||||
|
const auto ct = reinterpret_cast<const T*>(vct);
|
||||||
|
const auto gradHt = reinterpret_cast<const T*>(vgradHt);
|
||||||
|
const auto gradCt = reinterpret_cast<const T*>(vgradCt);
|
||||||
|
|
||||||
|
auto gradI = reinterpret_cast<T*>(vgradI);
|
||||||
|
auto gradWi = reinterpret_cast<T*>(vgradWi);
|
||||||
|
auto gradB = reinterpret_cast<T*>(vgradB);
|
||||||
|
auto gradC0 = reinterpret_cast<T*>(vgradC0);
|
||||||
|
|
||||||
|
const int rank = 3;
|
||||||
|
|
||||||
|
__shared__ int time, K;
|
||||||
|
__shared__ Nd4jLong len, totalThreads, *sharedMem;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
time = xShapeInfo[1];
|
||||||
|
K = xShapeInfo[3] / 2;
|
||||||
|
len = xShapeInfo[2] * xShapeInfo[3]; // 2*K*bS
|
||||||
|
|
||||||
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
BUILD_SINGLE_SELECTOR(x->dataType(), sruBI_, (x, w, b, c0, mask, ht, ct), FLOAT_TYPES);
|
Nd4jLong* coords = sharedMem + threadIdx.x * rank;
|
||||||
|
|
||||||
|
if(tid >= len)
|
||||||
|
return;
|
||||||
|
|
||||||
|
shape::index2coords(rank, xShapeInfo + 2, tid, len, coords + 1); // loop through last two dimensions of x : {bS, 2*K}
|
||||||
|
|
||||||
|
const auto maskOffst = mask ? shape::getOffset(0, maskShapeInfo + 1, maskShapeInfo + rank, coords + 1, rank - 1) : 0;
|
||||||
|
const auto c0Offset = shape::getOffset(0, c0ShapeInfo + 1, c0ShapeInfo + rank, coords + 1, rank - 1);
|
||||||
|
const auto gradCtOffset = shape::getOffset(0, gradCtShapeInfo + 1, gradCtShapeInfo + rank, coords + 1, rank - 1);
|
||||||
|
const auto gradC0Offset = shape::getOffset(0, gradC0ShapeInfo + 1, gradC0ShapeInfo + rank, coords + 1, rank - 1);
|
||||||
|
const auto bFOffset = shape::getOffset(0, bShapeInfo + 1, bShapeInfo + rank - 1, coords + 2, rank - 2);
|
||||||
|
const auto bROffset = bFOffset + 2 * K * bShapeInfo[2]; // 2*K*b_stride
|
||||||
|
// const auto gradBFOffset = shape::getOffset(0, gradBShapeInfo + 1, gradBShapeInfo + rank, coords + 1, rank - 1);
|
||||||
|
const auto gradBFOffset = coords[1] * gradBShapeInfo[3] / 2 + coords[2] * gradBShapeInfo[4];
|
||||||
|
const auto gradBROffset = gradBFOffset + gradBShapeInfo[3];
|
||||||
|
|
||||||
|
const bool flip = coords[2] >= K;
|
||||||
|
|
||||||
|
if(flip)
|
||||||
|
coords[0] = 0;
|
||||||
|
else
|
||||||
|
coords[0] = time - 1;
|
||||||
|
|
||||||
|
auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
|
||||||
|
auto ctOffset = shape::getOffset(0, ctShapeInfo + 1, ctShapeInfo + rank + 1, coords, rank);
|
||||||
|
auto gradIOffset = shape::getOffset(0, gradIShapeInfo + 1, gradIShapeInfo + rank + 1, coords, rank);
|
||||||
|
auto gradHtOffset = shape::getOffset(0, gradHtShapeInfo + 1, gradHtShapeInfo + rank + 1, coords, rank);
|
||||||
|
|
||||||
|
coords[2] *= 3;
|
||||||
|
auto gradWiOffset0 = shape::getOffset(0, gradWiShapeInfo + 1, gradWiShapeInfo + rank + 1, coords, rank);
|
||||||
|
auto gradWiOffset1 = gradWiOffset0 + gradWiShapeInfo[rank + 3]; // add last stride
|
||||||
|
auto gradWiOffset2 = gradWiOffset1 + gradWiShapeInfo[rank + 3]; // add last stride
|
||||||
|
auto wiOffset0 = shape::getOffset(0, wiShapeInfo + 1, wiShapeInfo + rank + 1, coords, rank);
|
||||||
|
auto wiOffset1 = wiOffset0 + wiShapeInfo[rank + 3]; // add last stride
|
||||||
|
auto wiOffset2 = wiOffset1 + wiShapeInfo[rank + 3]; // add last stride
|
||||||
|
|
||||||
|
const T xVal = x[xOffset];
|
||||||
|
const T maskVal = mask ? mask[maskOffst] : static_cast<T>(1);
|
||||||
|
const T c0Val = c0[c0Offset];
|
||||||
|
const T bF = b[bFOffset];
|
||||||
|
const T bR = b[bROffset];
|
||||||
|
T gradCtVal = gradCt[gradCtOffset];
|
||||||
|
T gbF = 0.f;
|
||||||
|
T gbR = 0.f;
|
||||||
|
|
||||||
|
// time loop
|
||||||
|
for (uint t = 0; t < time; ++t) {
|
||||||
|
|
||||||
|
// evaluate sigmoids
|
||||||
|
T ft = (1.f)/(1.f + nd4j::math::nd4j_exp<T, T>(-(wi[wiOffset1] + bF)));
|
||||||
|
T rt = (1.f)/(1.f + nd4j::math::nd4j_exp<T, T>(-(wi[wiOffset2] + bR)));
|
||||||
|
|
||||||
|
T val = nd4j::math::nd4j_tanh<T,T>(ct[ctOffset]);
|
||||||
|
|
||||||
|
T prevVal;
|
||||||
|
if(t < time-1)
|
||||||
|
prevVal = ct[ctOffset += flip ? ctShapeInfo[rank + 1] : -ctShapeInfo[rank + 1]];
|
||||||
|
else
|
||||||
|
prevVal = c0Val;
|
||||||
|
|
||||||
|
// grad wrt input
|
||||||
|
gradI[gradIOffset] = gradHt[gradHtOffset] - gradHt[gradHtOffset] * rt ;
|
||||||
|
|
||||||
|
// grad wrt rt, wiR and bR
|
||||||
|
T grt = gradHt[gradHtOffset] * (val * maskVal - x[xOffset]) * (rt - rt * rt);
|
||||||
|
gradWi[gradWiOffset2] = grt;
|
||||||
|
gbR += grt;
|
||||||
|
|
||||||
|
// grad wrt state
|
||||||
|
T gradC0Val = gradHt[gradHtOffset] * maskVal * (rt - rt * val * val) + gradCtVal;
|
||||||
|
|
||||||
|
// grad wrt wi0
|
||||||
|
gradWi[gradWiOffset0] = gradC0Val - gradC0Val * ft;
|
||||||
|
|
||||||
|
// grad wrt ft, wi1, and bF
|
||||||
|
T gft = gradC0Val * (prevVal - wi[wiOffset0]) * (ft - ft * ft);
|
||||||
|
gradWi[gradWiOffset1] = gft;
|
||||||
|
gbF += gft;
|
||||||
|
|
||||||
|
// grad wrt c_previous
|
||||||
|
gradCtVal = gradC0Val * ft;
|
||||||
|
|
||||||
|
if(flip) {
|
||||||
|
xOffset += xShapeInfo[rank + 1]; // first stride, corresponds to time step
|
||||||
|
gradHtOffset += gradHtShapeInfo[rank + 1];
|
||||||
|
gradIOffset += gradIShapeInfo[rank + 1];
|
||||||
|
wiOffset0 += wiShapeInfo[rank + 1];
|
||||||
|
wiOffset1 += wiShapeInfo[rank + 1];
|
||||||
|
wiOffset2 += wiShapeInfo[rank + 1];
|
||||||
|
gradWiOffset0 += gradWiShapeInfo[rank + 1];
|
||||||
|
gradWiOffset1 += gradWiShapeInfo[rank + 1];
|
||||||
|
gradWiOffset2 += gradWiShapeInfo[rank + 1];
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
xOffset -= xShapeInfo[rank + 1]; // first stride, corresponds to time step
|
||||||
|
gradHtOffset -= gradHtShapeInfo[rank + 1];
|
||||||
|
gradIOffset -= gradIShapeInfo[rank + 1];
|
||||||
|
wiOffset0 -= wiShapeInfo[rank + 1];
|
||||||
|
wiOffset1 -= wiShapeInfo[rank + 1];
|
||||||
|
wiOffset2 -= wiShapeInfo[rank + 1];
|
||||||
|
gradWiOffset0 -= gradWiShapeInfo[rank + 1];
|
||||||
|
gradWiOffset1 -= gradWiShapeInfo[rank + 1];
|
||||||
|
gradWiOffset2 -= gradWiShapeInfo[rank + 1];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void sruBIBP(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
|
gradB[gradBFOffset] = gbF;
|
||||||
BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBP_, (x, w, b, c0, ct, inGradC0, inGradH, mask, gradI, gradW, gradB, gradC0), FLOAT_TYPES);
|
gradB[gradBROffset] = gbR;
|
||||||
}
|
gradC0[gradC0Offset] = gradCtVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void sruBIBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
||||||
|
const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
const void* vwi, const Nd4jLong* wiShapeInfo,
|
||||||
|
const void* vb, const Nd4jLong* bShapeInfo,
|
||||||
|
const void* vc0, const Nd4jLong* c0ShapeInfo,
|
||||||
|
const void* vmask, const Nd4jLong* maskShapeInfo,
|
||||||
|
const void* vct, const Nd4jLong* ctShapeInfo,
|
||||||
|
const void* vgradHt, const Nd4jLong* gradHtShapeInfo,
|
||||||
|
const void* vgradCt, const Nd4jLong* gradCtShapeInfo,
|
||||||
|
void* vgradI, const Nd4jLong* gradIShapeInfo,
|
||||||
|
void* vgradWi, const Nd4jLong* gradWiShapeInfo,
|
||||||
|
void* vgradB, const Nd4jLong* gradBShapeInfo,
|
||||||
|
void* vgradC0, const Nd4jLong* gradC0ShapeInfo) {
|
||||||
|
|
||||||
|
sruBIBPCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vwi, wiShapeInfo, vb, bShapeInfo, vc0, c0ShapeInfo, vmask, maskShapeInfo, vct, ctShapeInfo, vgradHt, gradHtShapeInfo, vgradCt, gradCtShapeInfo, vgradI, gradIShapeInfo, vgradWi, gradWiShapeInfo, vgradB, gradBShapeInfo, vgradC0, gradC0ShapeInfo);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void sruBIBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, const void* vwi, const Nd4jLong* wiShapeInfo, const void* vb, const Nd4jLong* bShapeInfo, const void* vc0, const Nd4jLong* c0ShapeInfo, const void* vmask, const Nd4jLong* maskShapeInfo, const void* vct, const Nd4jLong* ctShapeInfo, const void* vgradHt, const Nd4jLong* gradHtShapeInfo, const void* vgradCt, const Nd4jLong* gradCtShapeInfo, void* vgradI, const Nd4jLong* gradIShapeInfo, void* vgradWi, const Nd4jLong* gradWiShapeInfo, void* vgradB, const Nd4jLong* gradBShapeInfo, void* vgradC0, const Nd4jLong* gradC0ShapeInfo), FLOAT_TYPES);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void sruBIBP(nd4j::LaunchContext* context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct,
|
||||||
|
const NDArray* gradCt, const NDArray* gradHt, const NDArray* mask,
|
||||||
|
NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
|
||||||
|
|
||||||
|
// x = x * mask
|
||||||
|
if(mask)
|
||||||
|
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
|
||||||
|
|
||||||
|
// U = x * w
|
||||||
|
NDArray wi = mmul(*x, *w); // U [time x bS x 6*K]
|
||||||
|
|
||||||
|
const int time = x->sizeAt(0);
|
||||||
|
const int bS = x->sizeAt(1);
|
||||||
|
const int K = x->sizeAt(2) / 2;
|
||||||
|
|
||||||
|
NDArray gradBias(x->ordering(), {bS, 4*K}, x->dataType(), context);
|
||||||
|
NDArray gradWi (x->ordering(), {time, bS, 6*K}, x->dataType(), context);
|
||||||
|
|
||||||
|
PointersManager manager(context, "sru_bi_bp");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (x->sizeAt(1) * x->sizeAt(2) + threadsPerBlock - 1) / threadsPerBlock; // loop through last two dimensions of x array -> bS, 2*K
|
||||||
|
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * x->rankOf() + 128;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({gradI, &gradWi, &gradBias, gradC0}, {x, &wi, b, c0, ct, gradCt, gradHt, mask});
|
||||||
|
BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), wi.getSpecialBuffer(), wi.getSpecialShapeInfo(), b->getSpecialBuffer(), b->getSpecialShapeInfo(), c0->getSpecialBuffer(), c0->getSpecialShapeInfo(), mask ? mask->getSpecialBuffer() : nullptr, mask ? mask->getSpecialShapeInfo() : nullptr, ct->getSpecialBuffer(), ct->getSpecialShapeInfo(), gradHt->getSpecialBuffer(), gradHt->getSpecialShapeInfo(), gradCt->getSpecialBuffer(), gradCt->getSpecialShapeInfo(), gradI->specialBuffer(), gradI->specialShapeInfo(), gradWi.specialBuffer(), gradWi.specialShapeInfo(), gradBias.specialBuffer(), gradBias.specialShapeInfo(), gradC0->specialBuffer(), gradC0->specialShapeInfo()), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({gradI, &gradWi, &gradBias, gradC0}, {x, &wi, b, c0, ct, gradCt, gradHt, mask});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
|
||||||
|
// gradB
|
||||||
|
gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}); // [4*K]
|
||||||
|
|
||||||
|
// gradW
|
||||||
|
x->permutei({0, 2, 1}); // [time, bS, 2*K] -> [time, 2*K, bS]
|
||||||
|
MmulHelper::mmul(x, &gradWi, gradW, 1., 0.); // [time, 2*K, bS] x [time, bS , 6*K] = [time, 2*K, 6*K]
|
||||||
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void sruBI_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct), FLOAT_TYPES);
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0), FLOAT_TYPES);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,639 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <svd.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cublas_v2.h>
|
||||||
|
#include <cusolverDn.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
#include <ShapeUtils.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
|
// FIXME -> we should optimize these helpers for the case when input matrices have c order (perform transpositions appropriately)
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ static void inverseColumnSignCuda(void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo) {
|
||||||
|
|
||||||
|
T* u = reinterpret_cast<T*>(vu);
|
||||||
|
T* v = reinterpret_cast<T*>(vv);
|
||||||
|
|
||||||
|
__shared__ int rank, uLastButOneColumn, vLastButOneColumn; // uRank = vRank
|
||||||
|
__shared__ Nd4jLong uLen, vLen;
|
||||||
|
__shared__ Nd4jLong *sharedMem;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
rank = shape::rank(uShapeInfo);
|
||||||
|
uLen = shape::length(uShapeInfo);
|
||||||
|
vLen = shape::length(vShapeInfo);
|
||||||
|
|
||||||
|
uLastButOneColumn = uShapeInfo[rank] - 2;
|
||||||
|
vLastButOneColumn = vShapeInfo[rank - 1] - 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const auto ind = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
|
||||||
|
auto coords = sharedMem + threadIdx.x * rank;
|
||||||
|
|
||||||
|
// u
|
||||||
|
for (Nd4jLong i = ind; i < uLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
shape::index2coords(rank, uShapeInfo + 1, i, uLen, coords);
|
||||||
|
|
||||||
|
if(coords[rank - 1] == 0 || coords[rank - 1] == uLastButOneColumn) // do not change sign in first and last but one columns
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const auto uOffset = shape::getOffset(0, uShapeInfo + 1, uShapeInfo + rank + 1, coords, rank);
|
||||||
|
|
||||||
|
u[uOffset] = -u[uOffset];
|
||||||
|
}
|
||||||
|
|
||||||
|
// v
|
||||||
|
for (Nd4jLong i = ind; i < vLen; i += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
shape::index2coords(rank, vShapeInfo + 1, i, vLen, coords);
|
||||||
|
|
||||||
|
if(coords[rank - 2] == 0 || coords[rank - 2] == vLastButOneColumn) // do not change sign in first and last but one columns
|
||||||
|
continue;
|
||||||
|
|
||||||
|
const auto vOffset = shape::getOffset(0, vShapeInfo + 1, vShapeInfo + rank + 1, coords, rank);
|
||||||
|
|
||||||
|
v[vOffset] = -v[vOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void inverseColumnSignCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
||||||
|
void* vu, const Nd4jLong* uShapeInfo,
|
||||||
|
void* vv, const Nd4jLong* vShapeInfo) {
|
||||||
|
|
||||||
|
inverseColumnSignCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vu, uShapeInfo, vv, vShapeInfo);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void inverseColumnSignCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, void* vu, const Nd4jLong* uShapeInfo, void* vv, const Nd4jLong* vShapeInfo), FLOAT_TYPES);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void svdQR(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& VT, const bool fullUV, const bool calcUV) {
|
||||||
|
|
||||||
|
// since cusa api cusolverDnDgesvd/cusolverDnSgesvd have following constrain on input matrix A: A_rows >= A_columns && A_order = 'f'
|
||||||
|
// we make this function to have deal with 2 valid cases only:
|
||||||
|
// 1) A_rows >= A_columns and A_corder = 'f'
|
||||||
|
// 2) A_rows <= A_columns and A_corder = 'c' - int this case perform transposition to get f order
|
||||||
|
// if 1) or 2) are not met then throw exception
|
||||||
|
|
||||||
|
// A [m, n]
|
||||||
|
// S [n]
|
||||||
|
// U [m, m] or [m, n] if fullUV = false and m > n
|
||||||
|
// VT [n, n] or [m, n] if fullUV = false and m < n
|
||||||
|
|
||||||
|
if(A.rankOf() != 2)
|
||||||
|
throw std::runtime_error("svdQR: rank of A array is not equal 2 !");
|
||||||
|
|
||||||
|
auto m = A.sizeAt(0);
|
||||||
|
auto n = A.sizeAt(1);
|
||||||
|
const int minDim = m < n ? m : n;
|
||||||
|
const char orderA = A.ordering();
|
||||||
|
|
||||||
|
if(m < n)
|
||||||
|
throw std::runtime_error("svdQR: due to cuda api input constrains given shape of A array are not valid !");
|
||||||
|
|
||||||
|
if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
|
||||||
|
throw std::runtime_error("svdQR: wrong shape of S array !");
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
|
||||||
|
if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
|
||||||
|
throw std::runtime_error("svdQR: wrong shape of U array !");
|
||||||
|
else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
|
||||||
|
throw std::runtime_error("svdQR: wrong shape of U array !");
|
||||||
|
|
||||||
|
if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&VT))
|
||||||
|
throw std::runtime_error("svdQR: wrong shape of VT array !");
|
||||||
|
else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(&VT))
|
||||||
|
throw std::runtime_error("svdQR: wrong shape of VT array !");
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray* pA = const_cast<NDArray*>(&A);
|
||||||
|
NDArray* pS = &S;
|
||||||
|
NDArray* pU = &U;
|
||||||
|
NDArray* pVT = &VT;
|
||||||
|
|
||||||
|
std::vector<NDArray*> toDelete;
|
||||||
|
|
||||||
|
if(pA->ews() != 1 || pA->ordering() == 'c') {
|
||||||
|
pA = A.dup('f');
|
||||||
|
toDelete.push_back(pA);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(S.ews() != 1) {
|
||||||
|
pS = S.dup('f');
|
||||||
|
toDelete.push_back(pS);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
|
||||||
|
if(pU->ews() != 1 || pU->ordering() == 'c') {
|
||||||
|
pU = U.dup('f');
|
||||||
|
toDelete.push_back(pU);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(pVT->ews() != 1 || pVT->ordering() == 'c') {
|
||||||
|
pVT = VT.dup('f');
|
||||||
|
toDelete.push_back(pVT);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create cusolverDn handle
|
||||||
|
cusolverDnHandle_t handle = nullptr;
|
||||||
|
cusolverStatus_t status = cusolverDnCreate(&handle);
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdQR: cuda failed !", status);
|
||||||
|
|
||||||
|
// stream
|
||||||
|
status = cusolverDnSetStream(handle, *context->getCudaStream());
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdQR: cuda failed !", status);
|
||||||
|
|
||||||
|
// query working space of SVD
|
||||||
|
int lwork = 0;
|
||||||
|
if(A.dataType() == DataType::DOUBLE)
|
||||||
|
status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork);
|
||||||
|
else if(A.dataType() == DataType::FLOAT32)
|
||||||
|
status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork);
|
||||||
|
else
|
||||||
|
throw std::invalid_argument("svdQR: given data type is unsupported !");
|
||||||
|
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdQR: cuda failed !", status);
|
||||||
|
|
||||||
|
// allocate memory for dWork
|
||||||
|
void* dWork = nullptr;
|
||||||
|
cudaError_t status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
|
||||||
|
if(status2 != cudaSuccess)
|
||||||
|
throw cuda_exception::build("svdQR: cuda failed !", status2);
|
||||||
|
|
||||||
|
signed char jobu, jobvt;
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
if(fullUV)
|
||||||
|
jobu = jobvt = 'A';
|
||||||
|
else
|
||||||
|
jobu = jobvt = 'S';
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
jobu = jobvt = 'N';
|
||||||
|
}
|
||||||
|
|
||||||
|
int *devInfo = nullptr;
|
||||||
|
void* rWork = nullptr;
|
||||||
|
|
||||||
|
int lda(m), ldu, ldvt;
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
ldu = pU->sizeAt(0);
|
||||||
|
ldvt = pVT->sizeAt(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
PointersManager manager(context, "svdQR");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({pS, pU, pVT}, {pA});
|
||||||
|
|
||||||
|
// choose appropriate cuda gemm api depending on data types
|
||||||
|
if(A.dataType() == DataType::DOUBLE) {
|
||||||
|
status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<double*>(dWork), lwork, reinterpret_cast<double*>(rWork), devInfo);
|
||||||
|
}
|
||||||
|
else if(A.dataType() == DataType::FLOAT32) {
|
||||||
|
status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pVT->getSpecialBuffer()), ldvt, reinterpret_cast<float*>(dWork), lwork, reinterpret_cast<float*>(rWork), devInfo);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
throw std::invalid_argument("svdQR: given data type is unsupported !");
|
||||||
|
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdQR: cuda failed !", status);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({pS, pU, pVT}, {pA});
|
||||||
|
|
||||||
|
S.assign(pS);
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
U.assign(pU);
|
||||||
|
VT.assign(pVT);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = toDelete.size() - 1; i >= 0; --i)
|
||||||
|
delete toDelete[i];
|
||||||
|
|
||||||
|
if (devInfo)
|
||||||
|
cudaFree(devInfo);
|
||||||
|
if (dWork )
|
||||||
|
cudaFree(dWork);
|
||||||
|
if (rWork)
|
||||||
|
cudaFree(rWork);
|
||||||
|
|
||||||
|
if(handle)
|
||||||
|
cusolverDnDestroy(handle);
|
||||||
|
|
||||||
|
// cudaDeviceReset();
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void svdJcb(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
|
||||||
|
|
||||||
|
// A [m, n]
|
||||||
|
// S [n]
|
||||||
|
// U [m, m] or [m, n] if fullUV = false and m > n
|
||||||
|
// V [n, n] or [n, m] if fullUV = false and m < n
|
||||||
|
|
||||||
|
if(A.rankOf() != 2)
|
||||||
|
throw std::runtime_error("svdJcb: rank of A array is not equal 2 !");
|
||||||
|
|
||||||
|
auto m = A.sizeAt(0);
|
||||||
|
auto n = A.sizeAt(1);
|
||||||
|
const int minDim = m < n ? m : n;
|
||||||
|
|
||||||
|
if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(&S))
|
||||||
|
throw std::runtime_error("svdJcb: wrong shape of S array !");
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
|
||||||
|
if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(&U))
|
||||||
|
throw std::runtime_error("svdJcb: wrong shape of U array !");
|
||||||
|
else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(&U))
|
||||||
|
throw std::runtime_error("svdJcb: wrong shape of U array !");
|
||||||
|
|
||||||
|
if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(&V))
|
||||||
|
throw std::runtime_error("svdJcb: wrong shape of V array !");
|
||||||
|
else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(&V))
|
||||||
|
throw std::runtime_error("svdJcb: wrong shape of V array !");
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray* pA = const_cast<NDArray*>(&A);
|
||||||
|
NDArray* pS = &S;
|
||||||
|
NDArray* pU = &U;
|
||||||
|
NDArray* pV = &V;
|
||||||
|
|
||||||
|
std::vector<NDArray*> toDelete;
|
||||||
|
|
||||||
|
if(pA->ews() != 1 || pA->ordering() == 'c') {
|
||||||
|
pA = A.dup('f');
|
||||||
|
toDelete.push_back(pA);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(S.ews() != 1) {
|
||||||
|
pS = S.dup('f');
|
||||||
|
toDelete.push_back(pS);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
|
||||||
|
if(pU->ews() != 1 || pU->ordering() == 'c') {
|
||||||
|
pU = U.dup('f');
|
||||||
|
toDelete.push_back(pU);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(pV->ews() != 1 || pV->ordering() == 'c') {
|
||||||
|
pV = V.dup('f');
|
||||||
|
toDelete.push_back(pV);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create cusolverDn handle
|
||||||
|
cusolverDnHandle_t handle = nullptr;
|
||||||
|
cusolverStatus_t status = cusolverDnCreate(&handle);
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||||
|
|
||||||
|
// stream
|
||||||
|
status = cusolverDnSetStream(handle, *context->getCudaStream());
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||||
|
|
||||||
|
// set parameters
|
||||||
|
gesvdjInfo_t gesvdjParams = nullptr;
|
||||||
|
status = cusolverDnCreateGesvdjInfo(&gesvdjParams);
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||||
|
status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||||
|
status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||||
|
|
||||||
|
int *devInfo = nullptr;
|
||||||
|
const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
|
||||||
|
const int econ = !fullUV;
|
||||||
|
|
||||||
|
int lda(m), ldu(m), ldv(m);
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
ldu = pU->sizeAt(0);
|
||||||
|
ldv = pV->sizeAt(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// query working space of SVD
|
||||||
|
int lwork = 0;
|
||||||
|
if(A.dataType() == DataType::DOUBLE)
|
||||||
|
status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
|
||||||
|
else if(A.dataType() == DataType::FLOAT32)
|
||||||
|
status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams);
|
||||||
|
else
|
||||||
|
throw std::invalid_argument("svdJcb: given data type is unsupported !");
|
||||||
|
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||||
|
|
||||||
|
// allocate memory dWork
|
||||||
|
void* dWork = nullptr;
|
||||||
|
auto status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
|
||||||
|
if(status2 != cudaSuccess)
|
||||||
|
throw cuda_exception::build("svdJcb: cuda failed !", status2);
|
||||||
|
|
||||||
|
PointersManager manager(context, "svdJcb");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({pS, pU, pV}, {pA});
|
||||||
|
|
||||||
|
// choose appropriate cuda gemm api depending on data types
|
||||||
|
if(A.dataType() == DataType::DOUBLE) {
|
||||||
|
status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams);
|
||||||
|
}
|
||||||
|
else if(A.dataType() == DataType::FLOAT32) {
|
||||||
|
status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
throw std::invalid_argument("svdJcb: given data type is unsupported !");
|
||||||
|
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdJcb: cuda failed !", status);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({pS, pU, pV}, {pA});
|
||||||
|
|
||||||
|
S.assign(pS);
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
U.assign(pU);
|
||||||
|
V.assign(pV);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = toDelete.size() - 1; i >= 0; --i)
|
||||||
|
delete toDelete[i];
|
||||||
|
|
||||||
|
if (devInfo)
|
||||||
|
cudaFree(devInfo);
|
||||||
|
if (dWork )
|
||||||
|
cudaFree(dWork);
|
||||||
|
if(handle)
|
||||||
|
cusolverDnDestroy(handle);
|
||||||
|
if(gesvdjParams)
|
||||||
|
cusolverDnDestroyGesvdjInfo(gesvdjParams);
|
||||||
|
|
||||||
|
// cudaDeviceReset();
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
static void svdBatched(nd4j::LaunchContext* context, const NDArray& A, NDArray& S, NDArray& U, NDArray& V, const bool fullUV, const bool calcUV) {
|
||||||
|
|
||||||
|
// A [..., m, n]
|
||||||
|
// S [..., n]
|
||||||
|
// U [..., m, m] or [..., m, n] if fullUV = false and m > n
|
||||||
|
// V [..., n, n] or [..., n, m] if fullUV = false and m < n
|
||||||
|
|
||||||
|
auto m = A.sizeAt(-2);
|
||||||
|
auto n = A.sizeAt(-1);
|
||||||
|
const int minDim = m < n ? m : n;
|
||||||
|
const Nd4jLong bS = A.lengthOf() / (m * n);
|
||||||
|
|
||||||
|
if(m > 32 || n > 32)
|
||||||
|
throw std::runtime_error("svdBatched: numbers of rows and columns should be <= 32 !");
|
||||||
|
|
||||||
|
if(minDim != S.sizeAt(-1))
|
||||||
|
throw std::runtime_error("svdBatched: wrong shape of S array !");
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
|
||||||
|
if(U.sizeAt(-2) != m)
|
||||||
|
throw std::runtime_error("svdBatched: wrong shape of U array !");
|
||||||
|
if(U.sizeAt(-1) != (fullUV ? m : minDim))
|
||||||
|
throw std::runtime_error("svdBatched: wrong shape of U array !");
|
||||||
|
if(U.lengthOf() / (U.sizeAt(-2) * U.sizeAt(-1)) != bS)
|
||||||
|
throw std::runtime_error("svdBatched: wrong shape of U array !");
|
||||||
|
|
||||||
|
if(V.sizeAt(-2) != n)
|
||||||
|
throw std::runtime_error("svdBatched: wrong shape of V array !");
|
||||||
|
if(V.sizeAt(-1) != (fullUV ? n : minDim))
|
||||||
|
throw std::runtime_error("svdBatched: wrong shape of V array !");
|
||||||
|
if(V.lengthOf() / (V.sizeAt(-2) * V.sizeAt(-1)) != bS)
|
||||||
|
throw std::runtime_error("svdBatched: wrong shape of V array !");
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray* pA = const_cast<NDArray*>(&A);
|
||||||
|
NDArray* pS = &S;
|
||||||
|
NDArray* pU = &U;
|
||||||
|
NDArray* pV = &V;
|
||||||
|
|
||||||
|
std::vector<NDArray*> toDelete;
|
||||||
|
|
||||||
|
if(pA->ews() != 1 || pA->ordering() == 'c') {
|
||||||
|
pA = A.dup('f');
|
||||||
|
toDelete.push_back(pA);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(S.ews() != 1) {
|
||||||
|
pS = S.dup('f');
|
||||||
|
toDelete.push_back(pS);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
|
||||||
|
if(pU->ews() != 1 || pU->ordering() == 'c') {
|
||||||
|
pU = U.dup('f');
|
||||||
|
toDelete.push_back(pU);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(pV->ews() != 1 || pV->ordering() == 'c') {
|
||||||
|
pV = V.dup('f');
|
||||||
|
toDelete.push_back(pV);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create cusolverDn handle
|
||||||
|
cusolverDnHandle_t handle = nullptr;
|
||||||
|
cusolverStatus_t status = cusolverDnCreate(&handle);
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status);
|
||||||
|
|
||||||
|
// stream
|
||||||
|
status = cusolverDnSetStream(handle, *context->getCudaStream());
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status);
|
||||||
|
|
||||||
|
// set parameters
|
||||||
|
gesvdjInfo_t gesvdjParams = nullptr;
|
||||||
|
status = cusolverDnCreateGesvdjInfo(&gesvdjParams);
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status);
|
||||||
|
status = cusolverDnXgesvdjSetTolerance(gesvdjParams, 1.e-7); // tolerance
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status);
|
||||||
|
status = cusolverDnXgesvdjSetMaxSweeps(gesvdjParams, 15); // max_sweeps
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status);
|
||||||
|
|
||||||
|
// devInfo
|
||||||
|
int *devInfo = nullptr;
|
||||||
|
auto status2 = cudaMalloc((void**)&devInfo, sizeof(int) * bS);
|
||||||
|
if(status2 != cudaSuccess)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status2);
|
||||||
|
status2 = cudaDeviceSynchronize();
|
||||||
|
if(status2 != cudaSuccess)
|
||||||
|
throw cuda_exception::build("svdJcb: cuda failed !", status2);
|
||||||
|
|
||||||
|
const cusolverEigMode_t jobz = calcUV ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
|
||||||
|
|
||||||
|
int lda(m), ldu, ldv;
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
ldu = pU->sizeAt(-2);
|
||||||
|
ldv = pV->sizeAt(-2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ak (i,j) = A[i + 5*j + 25*k]
|
||||||
|
|
||||||
|
// query working space of SVD
|
||||||
|
int lwork = 0;
|
||||||
|
if(A.dataType() == DataType::DOUBLE)
|
||||||
|
status = cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
|
||||||
|
else if(A.dataType() == DataType::FLOAT32)
|
||||||
|
status = cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, &lwork, gesvdjParams, bS);
|
||||||
|
else
|
||||||
|
throw std::invalid_argument("svdBatched: given data type is unsupported !");
|
||||||
|
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status);
|
||||||
|
|
||||||
|
// allocate memory dWork
|
||||||
|
void* dWork = nullptr;
|
||||||
|
status2 = cudaMalloc((void**)&dWork , A.sizeOfT() * lwork);
|
||||||
|
if(status2 != cudaSuccess)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status2);
|
||||||
|
status2 = cudaDeviceSynchronize();
|
||||||
|
if(status2 != cudaSuccess)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status2);
|
||||||
|
|
||||||
|
PointersManager manager(context, "svdBatched");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({pS, pU, pV}, {pA});
|
||||||
|
|
||||||
|
// choose appropriate cuda gemm api depending on data types
|
||||||
|
if(A.dataType() == DataType::DOUBLE) {
|
||||||
|
status = cusolverDnDgesvdjBatched(handle, jobz, m, n, reinterpret_cast<double*>(pA->getSpecialBuffer()), lda, reinterpret_cast<double*>(pS->getSpecialBuffer()), reinterpret_cast<double*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<double*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<double*>(dWork), lwork, devInfo, gesvdjParams, bS);
|
||||||
|
}
|
||||||
|
else if(A.dataType() == DataType::FLOAT32) {
|
||||||
|
status = cusolverDnSgesvdjBatched(handle, jobz, m, n, reinterpret_cast<float*>(pA->getSpecialBuffer()), lda, reinterpret_cast<float*>(pS->getSpecialBuffer()), reinterpret_cast<float*>(pU->getSpecialBuffer()), ldu, reinterpret_cast<float*>(pV->getSpecialBuffer()), ldv, reinterpret_cast<float*>(dWork), lwork, devInfo, gesvdjParams, bS);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
throw std::invalid_argument("svdBatched: given data type is unsupported !");
|
||||||
|
|
||||||
|
if(status != CUSOLVER_STATUS_SUCCESS)
|
||||||
|
throw cuda_exception::build("svdBatched: cuda failed !", status);
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({pS, pU, pV}, {pA});
|
||||||
|
|
||||||
|
S.assign(pS);
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
U.assign(pU);
|
||||||
|
V.assign(pV);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = toDelete.size() - 1; i >= 0; --i)
|
||||||
|
delete toDelete[i];
|
||||||
|
|
||||||
|
if (devInfo)
|
||||||
|
cudaFree(devInfo);
|
||||||
|
if (dWork )
|
||||||
|
cudaFree(dWork);
|
||||||
|
if(handle)
|
||||||
|
cusolverDnDestroy(handle);
|
||||||
|
if(gesvdjParams)
|
||||||
|
cusolverDnDestroyGesvdjInfo(gesvdjParams);
|
||||||
|
|
||||||
|
// cudaDeviceReset();
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum) {
|
||||||
|
|
||||||
|
NDArray* S = outArrs[0];
|
||||||
|
NDArray* U = outArrs[1];
|
||||||
|
// NDArray VT = outArrs[2]->transpose();
|
||||||
|
NDArray* V = outArrs[2];
|
||||||
|
|
||||||
|
if(x->rankOf() == 2) {
|
||||||
|
// svdQR(context, *x, *S, *U, VT, fullUV, calcUV);
|
||||||
|
svdJcb(context, *x, *S, *U, *V, fullUV, calcUV);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
// svdBatched(context, *x, *S, *U, *V, fullUV, calcUV);
|
||||||
|
|
||||||
|
ResultSet *tadsU(nullptr), *tadsV(nullptr);
|
||||||
|
|
||||||
|
auto tadsX = x->allTensorsAlongDimension({x->rankOf() - 2, x->rankOf() - 1});
|
||||||
|
auto tadsS = S->allTensorsAlongDimension({S->rankOf() - 1});
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
tadsU = U->allTensorsAlongDimension({U->rankOf() - 2, U->rankOf() - 1});
|
||||||
|
tadsV = V->allTensorsAlongDimension({V->rankOf() - 2, V->rankOf() - 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < tadsX->size(); ++i)
|
||||||
|
svdJcb(context, *tadsX->at(i), *tadsS->at(i), calcUV ? *tadsU->at(i) : *S, calcUV ? *tadsV->at(i) : *S, fullUV, calcUV);
|
||||||
|
|
||||||
|
delete tadsX;
|
||||||
|
delete tadsS;
|
||||||
|
|
||||||
|
if(calcUV) {
|
||||||
|
delete tadsU;
|
||||||
|
delete tadsV;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -323,40 +323,148 @@ void trace(nd4j::LaunchContext* context, const NDArray& input, NDArray& output)
|
||||||
manager.synchronize();
|
manager.synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ static void triuBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag) {
|
||||||
|
|
||||||
|
// x and z have same shapes
|
||||||
|
const auto x = reinterpret_cast<const T*>(vx); // gradO
|
||||||
|
auto z = reinterpret_cast<T*>(vz); // gradI
|
||||||
|
|
||||||
|
__shared__ int rank, areSameOffsets; // xRank = zRank
|
||||||
|
__shared__ Nd4jLong len, totalThreads, *sharedMem; // xLen = zLen
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
areSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
|
||||||
|
rank = shape::rank(xShapeInfo);
|
||||||
|
len = shape::length(zShapeInfo);
|
||||||
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
static void triuBP_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void triuBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) {
|
__syncthreads();
|
||||||
BUILD_SINGLE_SELECTOR(gradO.dataType(), triuBP_, (context, input, gradO, gradI, diagonal), LIBND4J_TYPES);
|
|
||||||
|
auto coords = sharedMem + threadIdx.x * rank;
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
for (Nd4jLong i = tid; i < len; i += totalThreads) {
|
||||||
|
|
||||||
|
shape::index2coords(rank, zShapeInfo + 1, i, len, coords);
|
||||||
|
|
||||||
|
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
|
||||||
|
|
||||||
|
if((coords[rank - 2] + diag > coords[rank - 1])) // row + diag > col
|
||||||
|
z[zOffset] = 0;
|
||||||
|
else
|
||||||
|
z[zOffset] = x[areSameOffsets ? zOffset : shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void triuBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag) {
|
||||||
|
|
||||||
|
triuBPCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, diag);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void triuBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int diag), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void triuBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal) {
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * gradO.rankOf() + 128;
|
||||||
|
|
||||||
|
PointersManager manager(context, "triuBP");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&gradI}, {&gradO});
|
||||||
|
BUILD_SINGLE_SELECTOR(gradI.dataType(), triuBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), diagonal), LIBND4J_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&gradI}, {&gradO});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
__global__ static void tileBPCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem) {
|
||||||
|
|
||||||
|
// x and z have same shapes
|
||||||
|
const auto x = reinterpret_cast<const T*>(vx); // gradO
|
||||||
|
auto z = reinterpret_cast<T*>(vz); // gradI
|
||||||
|
|
||||||
|
__shared__ int xRank, zRank; // xRank >= zRank
|
||||||
|
__shared__ Nd4jLong numOfXOffsets, zLen, totalThreads, *sharedMem; // xLen >= zLen
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
|
||||||
|
xRank = shape::rank(zShapeInfo);
|
||||||
|
zLen = shape::length(zShapeInfo);
|
||||||
|
numOfXOffsets = shape::length(xShapeInfo) / zLen;
|
||||||
|
|
||||||
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void triuBP_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int diagonal), LIBND4J_TYPES);
|
__syncthreads();
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
auto memBuff = sharedMem + threadIdx.x * 2 * xRank;
|
||||||
|
auto xOffsets = globMem + tid * numOfXOffsets;
|
||||||
|
|
||||||
|
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||||
|
|
||||||
|
const auto zOffset = shape::getIndexOffset(i, zShapeInfo, zLen);
|
||||||
|
|
||||||
|
shape::outerArrayOffsets(xOffsets, i, xShapeInfo, zShapeInfo, memBuff);
|
||||||
|
|
||||||
|
z[zOffset] = x[xOffsets[0]]; // first offset
|
||||||
|
for (Nd4jLong j = 1; j < numOfXOffsets; ++j) // rest offsets
|
||||||
|
z[zOffset] += x[xOffsets[j]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
static void tileBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem) {
|
||||||
|
|
||||||
|
tileBPCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, globMem);
|
||||||
|
}
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void tileBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* globMem), FLOAT_TYPES);
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps) {
|
||||||
|
|
||||||
|
NDArray memBuff('c', gradO.getShapeAsVector(), nd4j::DataType::INT64, context); // empty auxiliary array for storing device memory which will be used in kernel calculations
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||||
|
const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * 2 * gradO.rankOf() + 128;
|
||||||
|
|
||||||
|
PointersManager manager(context, "tileBP");
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&gradI}, {&gradO, &memBuff});
|
||||||
|
BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), reinterpret_cast<Nd4jLong*>(memBuff.specialBuffer())), FLOAT_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&gradI}, {&gradO, &memBuff});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -1036,18 +1144,6 @@ void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs,
|
||||||
scatterSimpleKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), xLength, packX.numberOfTads(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), iLength, updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), uLength);
|
scatterSimpleKernel<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), xLength, packX.numberOfTads(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), iLength, updates.getSpecialBuffer(), updates.getSpecialShapeInfo(), uLength);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template <typename T>
|
|
||||||
static void tileBP_(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps) {
|
|
||||||
BUILD_SINGLE_SELECTOR(gradI.dataType(), tileBP_, (context, gradO, gradI, reps), FLOAT_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void tileBP_, (nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps), FLOAT_TYPES);
|
|
||||||
|
|
||||||
void scatterSimple(nd4j::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions) {
|
void scatterSimple(nd4j::LaunchContext * context, const int opId, NDArray& input, const NDArray& updates, const NDArray& indices, const std::vector<int>& dimensions) {
|
||||||
auto xType = input.dataType();
|
auto xType = input.dataType();
|
||||||
|
|
|
@ -20,63 +20,68 @@
|
||||||
|
|
||||||
#include <ops/declarable/helpers/helpers.h>
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
void dilation2d(nd4j::LaunchContext * context, NDArray *input, NDArray *weights, NDArray *output, int stride_rows, int stride_cols, int rate_rows, int rate_cols, int pad_top, int pad_left);
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
void dilation2d(nd4j::LaunchContext* context, NDArray *input, NDArray *weights, NDArray *output, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW);
|
||||||
|
|
||||||
FORCEINLINE Nd4jStatus outputSize(nd4j::LaunchContext * context, int input_size, int filter_size, int dilation_rate, int stride, bool isSameMode, int *output_size, int *padding_before, int *padding_after) {
|
//////////////////////////////////////////////////////////////////////
|
||||||
if (stride <= 0)
|
FORCEINLINE Nd4jStatus outputSize(nd4j::LaunchContext * context, const int inSize, const int k, const int d, const int s, bool isSameMode, int *outSize, int *padding_before, int *padding_after) {
|
||||||
return Status::THROW("Dilation2D: Stride must be > 0");
|
if (s <= 0)
|
||||||
|
return Status::THROW("Dilation2D: Stride must be > 0");
|
||||||
|
|
||||||
if (dilation_rate < 1)
|
if (d < 1)
|
||||||
return Status::THROW("Dilation2D: Dilation rate must be >= 1");
|
return Status::THROW("Dilation2D: Dilation rate must be >= 1");
|
||||||
|
|
||||||
int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
|
int kEff = (k - 1) * d + 1;
|
||||||
if (isSameMode) {
|
if (isSameMode) {
|
||||||
*output_size = (input_size + stride - 1) / stride;
|
*outSize = (inSize + s - 1) / s;
|
||||||
const int padding_needed = nd4j::math::nd4j_max<int>(0, (*output_size - 1) * stride + effective_filter_size -input_size);
|
const int padding_needed = nd4j::math::nd4j_max<int>(0, (*outSize - 1) * s + kEff -inSize);
|
||||||
|
|
||||||
*padding_before = padding_needed / 2;
|
*padding_before = padding_needed / 2;
|
||||||
*padding_after = padding_needed - *padding_before;
|
*padding_after = padding_needed - *padding_before;
|
||||||
} else {
|
} else {
|
||||||
*output_size = (input_size - effective_filter_size + stride) / stride;
|
*outSize = (inSize - kEff + s) / s;
|
||||||
*padding_before = *padding_after = 0;
|
*padding_before = *padding_after = 0;
|
||||||
}
|
|
||||||
|
|
||||||
if (*output_size < 0)
|
|
||||||
return Status::THROW("Dilation2D: output_size has negative value");
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (*outSize < 0)
|
||||||
|
return Status::THROW("Dilation2D: outSize has negative value");
|
||||||
|
|
||||||
FORCEINLINE Nd4jStatus dilation_hw(nd4j::LaunchContext * context, Nd4jLong *in, Nd4jLong *wh, std::vector<int> &strides, std::vector<int> &rates, bool isSameMode, int *stride_rows, int *stride_cols, int *rate_rows, int *rate_cols, int *pad_top, int *pad_left, int *out_rows, int *out_cols) {
|
return Status::OK();
|
||||||
const int input_rows = shape::sizeAt(in, 1);
|
}
|
||||||
const int input_cols = shape::sizeAt(in, 2);
|
|
||||||
const int depth = shape::sizeAt(in, 3);
|
|
||||||
|
|
||||||
*stride_rows = strides[1];
|
//////////////////////////////////////////////////////////////////////
|
||||||
*stride_cols = strides[2];
|
FORCEINLINE Nd4jStatus dilation_hw(nd4j::LaunchContext * context, Nd4jLong *in, Nd4jLong *wh, std::vector<int> &strides, std::vector<int> &rates, bool isSameMode, int *sH, int *sW, int *pH, int *pW, int *dH, int *dW, int *oH, int *oW) {
|
||||||
*rate_rows = rates[1];
|
const int iH = shape::sizeAt(in, 1);
|
||||||
*rate_cols = rates[2];
|
const int iW = shape::sizeAt(in, 2);
|
||||||
|
const int iC = shape::sizeAt(in, 3);
|
||||||
|
|
||||||
const int filter_rows = shape::sizeAt(wh, 0);
|
*sH = strides[1];
|
||||||
const int filter_cols = shape::sizeAt(wh, 1);
|
*sW = strides[2];
|
||||||
|
*dH = rates[1];
|
||||||
|
*dW = rates[2];
|
||||||
|
|
||||||
const int filter_rows_eff = filter_rows + (filter_rows - 1) * (*rate_rows - 1);
|
const int kH = shape::sizeAt(wh, 0);
|
||||||
const int filter_cols_eff = filter_cols + (filter_cols - 1) * (*rate_cols - 1);
|
const int kW = shape::sizeAt(wh, 1);
|
||||||
|
|
||||||
|
const int kHeff = kH + (kH - 1) * (*dH - 1);
|
||||||
|
const int kWeff = kW + (kW - 1) * (*dW - 1);
|
||||||
|
|
||||||
|
int padding_after_unusedA, padding_after_unusedB;
|
||||||
|
if (outputSize(context, iH, kHeff, 1, *sH, isSameMode, oH, pH, &padding_after_unusedA) != Status::OK())
|
||||||
|
return Status::THROW("Dilation2D: bad height");
|
||||||
|
|
||||||
|
if (outputSize(context, iW, kWeff, 1, *sW, isSameMode, oW, pW, &padding_after_unusedA) != Status::OK())
|
||||||
|
return Status::THROW("Dilation2D: bad width");
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
int padding_after_unusedA, padding_after_unusedB;
|
|
||||||
if (outputSize(context, input_rows, filter_rows_eff, 1, *stride_rows, isSameMode, out_rows, pad_top, &padding_after_unusedA) != Status::OK())
|
|
||||||
return Status::THROW("Dilation2D: bad height");
|
|
||||||
|
|
||||||
if (outputSize(context, input_cols, filter_cols_eff, 1, *stride_cols, isSameMode, out_cols, pad_left, &padding_after_unusedA) != Status::OK())
|
|
||||||
return Status::THROW("Dilation2D: bad width");
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -30,7 +30,7 @@ namespace helpers {
|
||||||
T lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation);
|
T lup(nd4j::LaunchContext * context, NDArray* input, NDArray* compound, NDArray* permutation);
|
||||||
|
|
||||||
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
int determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
||||||
int log_abs_determinant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
int logAbsDeterminant(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
||||||
|
|
||||||
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
int inverse(nd4j::LaunchContext * context, NDArray* input, NDArray* output);
|
||||||
|
|
||||||
|
|
|
@ -26,11 +26,11 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock);
|
void scatter(nd4j::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock);
|
||||||
|
|
||||||
void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock);
|
void scatterND(nd4j::LaunchContext* context, pairwise::Ops op, const NDArray& indices, const NDArray& updates, NDArray& output, const bool lock);
|
||||||
|
|
||||||
void scatterForLoss(nd4j::LaunchContext *context, const NDArray& indices, const NDArray& updates, NDArray& output, const bool calcGrad);
|
void scatterForLoss(nd4j::LaunchContext* context, const NDArray& indices, NDArray& updates, NDArray& output, const bool calcGrad);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace helpers {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// svd operation, this function is not method of SVD class, it is standalone function
|
// svd operation, this function is not method of SVD class, it is standalone function
|
||||||
void svd(nd4j::LaunchContext * context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum);
|
void svd(nd4j::LaunchContext* context, const NDArray* x, const std::vector<NDArray*>& outArrs, const bool fullUV, const bool calcUV, const int switchNum);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1197,7 +1197,7 @@ inline __device__ float nd4j_atomicMul<float>(float* address, float val) {
|
||||||
do {
|
do {
|
||||||
assumed = old;
|
assumed = old;
|
||||||
old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
|
old = atomicCAS(address_as_ull, assumed, __float_as_int(val *
|
||||||
__float_as_int(assumed)));
|
__int_as_float(assumed)));
|
||||||
} while (assumed != old);
|
} while (assumed != old);
|
||||||
return __int_as_float(old);
|
return __int_as_float(old);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2595,7 +2595,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_1) {
|
||||||
|
|
||||||
NDArray input('c', {N,bS,2*K}, nd4j::DataType::DOUBLE);
|
NDArray input('c', {N,bS,2*K}, nd4j::DataType::DOUBLE);
|
||||||
NDArray weights('c', {2*K,6*K}, nd4j::DataType::DOUBLE);
|
NDArray weights('c', {2*K,6*K}, nd4j::DataType::DOUBLE);
|
||||||
NDArray bias('c', {1,4*K}, nd4j::DataType::DOUBLE);
|
NDArray bias('c', {4*K}, nd4j::DataType::DOUBLE);
|
||||||
NDArray init('c', {bS,2*K}, nd4j::DataType::DOUBLE);
|
NDArray init('c', {bS,2*K}, nd4j::DataType::DOUBLE);
|
||||||
NDArray mask('c', {bS,2*K}, nd4j::DataType::DOUBLE);
|
NDArray mask('c', {bS,2*K}, nd4j::DataType::DOUBLE);
|
||||||
NDArray expState('c', {N,bS,2*K}, {1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857});
|
NDArray expState('c', {N,bS,2*K}, {1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857});
|
||||||
|
@ -2635,7 +2635,7 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {N,bS,2*K});
|
auto input = NDArrayFactory::create<double>('c', {N,bS,2*K});
|
||||||
auto weights = NDArrayFactory::create<double>('c', {2*K,6*K});
|
auto weights = NDArrayFactory::create<double>('c', {2*K,6*K});
|
||||||
auto bias = NDArrayFactory::create<double>('c', {1,4*K});
|
auto bias = NDArrayFactory::create<double>('c', {4*K});
|
||||||
auto init = NDArrayFactory::create<double>('c', {bS,2*K});
|
auto init = NDArrayFactory::create<double>('c', {bS,2*K});
|
||||||
auto mask = NDArrayFactory::create<double>('c', {bS,2*K});
|
auto mask = NDArrayFactory::create<double>('c', {bS,2*K});
|
||||||
NDArray state('c', {N,bS,2*K}, stateBuff);
|
NDArray state('c', {N,bS,2*K}, stateBuff);
|
||||||
|
@ -2646,8 +2646,8 @@ TEST_F(DeclarableOpsTests1, sru_bi_bp_1) {
|
||||||
|
|
||||||
NDArray expGradX('c', {N,bS,2*K}, expGradXBuff);
|
NDArray expGradX('c', {N,bS,2*K}, expGradXBuff);
|
||||||
NDArray expGradW('c', {N,2*K,6*K}, expGradWBuff);
|
NDArray expGradW('c', {N,2*K,6*K}, expGradWBuff);
|
||||||
auto expGradB = NDArrayFactory::create<double>('c', {1,4*K});
|
auto expGradB = NDArrayFactory::create<double>('c', {4*K});
|
||||||
gradBias.reduceAlongDimension(reduce::Sum, &expGradB, {0}, false, true); // [bS x 4K] -> [1 x 4K]
|
gradBias.reduceAlongDimension(reduce::Sum, &expGradB, {0}); // [bS, 4K] -> [4K]
|
||||||
NDArray expGradInit('c', {bS,2*K}, expGradInitBuff);
|
NDArray expGradInit('c', {bS,2*K}, expGradInitBuff);
|
||||||
|
|
||||||
input.assign(1.5);
|
input.assign(1.5);
|
||||||
|
|
|
@ -391,30 +391,6 @@ TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) {
|
||||||
delete res;
|
delete res;
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests10, svd_test11) {
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {3,3}, {1.,2.,3.,4.,5.,6.,7.,8.,9.});
|
|
||||||
auto expS = NDArrayFactory::create<double>('c', {3});
|
|
||||||
auto expU = NDArrayFactory::create<double>('c', {3,3});
|
|
||||||
auto expV = NDArrayFactory::create<double>('c', {3,3});
|
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
|
||||||
auto results = op.execute({&x}, {}, {0, 1, 16});
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
||||||
|
|
||||||
auto s = results->at(0);
|
|
||||||
auto u = results->at(1);
|
|
||||||
auto v = results->at(2);
|
|
||||||
|
|
||||||
ASSERT_TRUE(expS.isSameShape(s));
|
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
|
||||||
ASSERT_TRUE(expV.isSameShape(v));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) {
|
TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) {
|
||||||
|
|
||||||
|
|
|
@ -2093,21 +2093,30 @@ TEST_F(DeclarableOpsTests3, svd_test1) {
|
||||||
auto *u = results->at(1);
|
auto *u = results->at(1);
|
||||||
auto *v = results->at(2);
|
auto *v = results->at(2);
|
||||||
|
|
||||||
ASSERT_TRUE(expS.equalsTo(s));
|
|
||||||
ASSERT_TRUE(expU.equalsTo(u));
|
|
||||||
ASSERT_TRUE(expV.equalsTo(v));
|
|
||||||
|
|
||||||
ASSERT_TRUE(expS.isSameShape(s));
|
ASSERT_TRUE(expS.isSameShape(s));
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
ASSERT_TRUE(expV.isSameShape(v));
|
ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expS.equalsTo(s));
|
||||||
|
|
||||||
|
if(nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
||||||
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests3, svd_test2) {
|
TEST_F(DeclarableOpsTests3, svd_test2) {
|
||||||
|
|
||||||
auto x= NDArrayFactory::create<float>('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2});
|
auto x = NDArrayFactory::create<float>('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2});
|
||||||
auto expS= NDArrayFactory::create<float>('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672});
|
auto expS= NDArrayFactory::create<float>('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672});
|
||||||
auto expU= NDArrayFactory::create<float>('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178});
|
auto expU= NDArrayFactory::create<float>('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178});
|
||||||
auto expV= NDArrayFactory::create<float>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
|
auto expV= NDArrayFactory::create<float>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});
|
||||||
|
@ -2121,14 +2130,23 @@ TEST_F(DeclarableOpsTests3, svd_test2) {
|
||||||
auto *u = results->at(1);
|
auto *u = results->at(1);
|
||||||
auto *v = results->at(2);
|
auto *v = results->at(2);
|
||||||
|
|
||||||
ASSERT_TRUE(expS.equalsTo(s));
|
|
||||||
ASSERT_TRUE(expU.equalsTo(u));
|
|
||||||
ASSERT_TRUE(expV.equalsTo(v));
|
|
||||||
|
|
||||||
ASSERT_TRUE(expS.isSameShape(s));
|
ASSERT_TRUE(expS.isSameShape(s));
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
ASSERT_TRUE(expV.isSameShape(v));
|
ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expS.equalsTo(s));
|
||||||
|
|
||||||
|
if(nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
||||||
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2149,14 +2167,23 @@ TEST_F(DeclarableOpsTests3, svd_test3) {
|
||||||
auto *u = results->at(1);
|
auto *u = results->at(1);
|
||||||
auto *v = results->at(2);
|
auto *v = results->at(2);
|
||||||
|
|
||||||
ASSERT_TRUE(expS.equalsTo(s));
|
|
||||||
ASSERT_TRUE(expU.equalsTo(u));
|
|
||||||
ASSERT_TRUE(expV.equalsTo(v));
|
|
||||||
|
|
||||||
ASSERT_TRUE(expS.isSameShape(s));
|
ASSERT_TRUE(expS.isSameShape(s));
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
ASSERT_TRUE(expV.isSameShape(v));
|
ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expS.equalsTo(s));
|
||||||
|
|
||||||
|
if(nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
||||||
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2177,14 +2204,23 @@ TEST_F(DeclarableOpsTests3, svd_test4) {
|
||||||
auto *u = results->at(1);
|
auto *u = results->at(1);
|
||||||
auto *v = results->at(2);
|
auto *v = results->at(2);
|
||||||
|
|
||||||
ASSERT_TRUE(expS.equalsTo(s));
|
|
||||||
ASSERT_TRUE(expU.equalsTo(u));
|
|
||||||
ASSERT_TRUE(expV.equalsTo(v));
|
|
||||||
|
|
||||||
ASSERT_TRUE(expS.isSameShape(s));
|
ASSERT_TRUE(expS.isSameShape(s));
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
ASSERT_TRUE(expV.isSameShape(v));
|
ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expS.equalsTo(s));
|
||||||
|
|
||||||
|
if(nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
||||||
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2205,14 +2241,23 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
|
||||||
auto *u = results->at(1);
|
auto *u = results->at(1);
|
||||||
auto *v = results->at(2);
|
auto *v = results->at(2);
|
||||||
|
|
||||||
ASSERT_TRUE(expS.equalsTo(s));
|
|
||||||
ASSERT_TRUE(expU.equalsTo(u));
|
|
||||||
ASSERT_TRUE(expV.equalsTo(v));
|
|
||||||
|
|
||||||
ASSERT_TRUE(expS.isSameShape(s));
|
ASSERT_TRUE(expS.isSameShape(s));
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
ASSERT_TRUE(expV.isSameShape(v));
|
ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expS.equalsTo(s));
|
||||||
|
|
||||||
|
if(nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
||||||
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2220,56 +2265,27 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
|
||||||
TEST_F(DeclarableOpsTests3, svd_test6) {
|
TEST_F(DeclarableOpsTests3, svd_test6) {
|
||||||
|
|
||||||
auto x= NDArrayFactory::create<float>('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2
|
auto x= NDArrayFactory::create<float>('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2
|
||||||
,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10
|
,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17
|
||||||
,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17
|
,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14
|
||||||
,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2
|
,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5});
|
||||||
,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14
|
|
||||||
,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16
|
|
||||||
,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5});
|
|
||||||
auto expS= NDArrayFactory::create<float>('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,
|
auto expS= NDArrayFactory::create<float>('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,
|
||||||
38.18412, 31.52287, 23.52755, 11.79484, 1.90195,
|
38.18412, 31.52287, 23.52755, 11.79484, 1.90195,
|
||||||
39.34498, 32.54861, 17.52492, 7.03003, 2.2399,
|
39.34498, 32.54861, 17.52492, 7.03003, 2.2399,
|
||||||
44.72126, 32.3164 , 16.60139, 6.88783, 0.78122});
|
44.72126, 32.3164 , 16.60139, 6.88783, 0.78122});
|
||||||
auto expU= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054,
|
auto expU= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054,
|
||||||
-0.32285, -0.58332, 0.3451 , 0.4746 , -0.45953,
|
-0.32285, -0.58332, 0.3451 , 0.4746 , -0.45953,0.58332, 0.10605, 0.51533, 0.50234, 0.36136,0.12588, -0.73123, -0.37812, -0.00215, 0.55361,
|
||||||
0.58332, 0.10605, 0.51533, 0.50234, 0.36136,
|
0.68915, -0.2919 , 0.04767, -0.4197 , -0.51132,0.44464, -0.25326, -0.42493, -0.01712, -0.74653,0.516 , -0.16688, 0.1854 , -0.77155, 0.27611,
|
||||||
0.12588, -0.73123, -0.37812, -0.00215, 0.55361,
|
-0.19321, -0.14317, -0.85886, -0.15224, 0.42585,-0.60155, -0.68323, 0.18819, -0.29053, -0.22696,-0.36993, 0.64862, -0.10956, -0.54483, -0.36552,
|
||||||
0.68915, -0.2919 , 0.04767, -0.4197 , -0.51132,
|
-0.57697, -0.32277, 0.11229, 0.55495, 0.4923 ,-0.02937, 0.01689, -0.63257, 0.57075, -0.52245,-0.56002, -0.2036 , -0.53119, -0.6022 , 0.01017,
|
||||||
0.44464, -0.25326, -0.42493, -0.01712, -0.74653,
|
-0.33605, -0.35257, 0.53215, -0.04936, -0.69075,0.48958, -0.85427, -0.14796, -0.03449, 0.08633,0.15008, 0.60996, 0.31071, -0.67721, 0.22421,
|
||||||
0.516 , -0.16688, 0.1854 , -0.77155, 0.27611,
|
0.67717, -0.59857, 0.04372, -0.2565 , 0.33979,0.68116, 0.49852, -0.13441, 0.51374, -0.07421,-0.20066, 0.04504, 0.42865, 0.44418, 0.75939,0.12113, -0.13826, 0.83651, 0.11988, -0.50209});
|
||||||
-0.19321, -0.14317, -0.85886, -0.15224, 0.42585,
|
|
||||||
-0.60155, -0.68323, 0.18819, -0.29053, -0.22696,
|
|
||||||
-0.36993, 0.64862, -0.10956, -0.54483, -0.36552,
|
|
||||||
-0.57697, -0.32277, 0.11229, 0.55495, 0.4923 ,
|
|
||||||
-0.02937, 0.01689, -0.63257, 0.57075, -0.52245,
|
|
||||||
-0.56002, -0.2036 , -0.53119, -0.6022 , 0.01017,
|
|
||||||
-0.33605, -0.35257, 0.53215, -0.04936, -0.69075,
|
|
||||||
0.48958, -0.85427, -0.14796, -0.03449, 0.08633,
|
|
||||||
0.15008, 0.60996, 0.31071, -0.67721, 0.22421,
|
|
||||||
0.67717, -0.59857, 0.04372, -0.2565 , 0.33979,
|
|
||||||
0.68116, 0.49852, -0.13441, 0.51374, -0.07421,
|
|
||||||
-0.20066, 0.04504, 0.42865, 0.44418, 0.75939,
|
|
||||||
0.12113, -0.13826, 0.83651, 0.11988, -0.50209});
|
|
||||||
auto expV= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781,
|
auto expV= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781,
|
||||||
0.59651, -0.13439, -0.395 , 0.66979, 0.14654,
|
0.59651, -0.13439, -0.395 , 0.66979, 0.14654,0.73731, 0.47061, 0.19357, -0.41127, -0.16817,0.1047 , -0.29727, 0.73711, 0.38235, -0.45951,
|
||||||
0.73731, 0.47061, 0.19357, -0.41127, -0.16817,
|
-0.29873, 0.80012, -0.02078, 0.4651 , -0.23201,-0.05314, -0.0419 , -0.52146, 0.77792, 0.344 ,-0.66438, 0.05648, 0.03756, -0.31531, 0.67422,
|
||||||
0.1047 , -0.29727, 0.73711, 0.38235, -0.45951,
|
0.74471, 0.01504, -0.03081, -0.24335, 0.62049,0.03172, 0.91947, 0.30828, 0.23713, 0.04796,-0.01311, 0.38652, -0.79415, -0.42423, -0.19945,
|
||||||
-0.29873, 0.80012, -0.02078, 0.4651 , -0.23201,
|
-0.13783, -0.54667, -0.58527, 0.49955, 0.3001 ,0.85214, 0.01628, 0.02688, -0.02891, 0.52157,0.16608, -0.20181, 0.61371, 0.69894, -0.25794,
|
||||||
-0.05314, -0.0419 , -0.52146, 0.77792, 0.344 ,
|
0.45726, -0.33952, -0.32659, -0.18938, -0.73015,0.13486, 0.73816, -0.41646, 0.47458, -0.1956 ,0.5536 , -0.137 , 0.64688, 0.50536, 0.03017,
|
||||||
-0.66438, 0.05648, 0.03756, -0.31531, 0.67422,
|
-0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428});
|
||||||
0.74471, 0.01504, -0.03081, -0.24335, 0.62049,
|
|
||||||
0.03172, 0.91947, 0.30828, 0.23713, 0.04796,
|
|
||||||
-0.01311, 0.38652, -0.79415, -0.42423, -0.19945,
|
|
||||||
-0.13783, -0.54667, -0.58527, 0.49955, 0.3001 ,
|
|
||||||
0.85214, 0.01628, 0.02688, -0.02891, 0.52157,
|
|
||||||
0.16608, -0.20181, 0.61371, 0.69894, -0.25794,
|
|
||||||
0.45726, -0.33952, -0.32659, -0.18938, -0.73015,
|
|
||||||
0.13486, 0.73816, -0.41646, 0.47458, -0.1956 ,
|
|
||||||
0.5536 , -0.137 , 0.64688, 0.50536, 0.03017,
|
|
||||||
-0.51827, -0.31837, -0.16732, 0.71378, -0.30425,
|
|
||||||
-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,
|
|
||||||
-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,
|
|
||||||
-0.01282, 0.92491, -0.08042, 0.36977, -0.03428});
|
|
||||||
|
|
||||||
nd4j::ops::svd op;
|
nd4j::ops::svd op;
|
||||||
auto results = op.execute({&x}, {}, {1, 1, 16});
|
auto results = op.execute({&x}, {}, {1, 1, 16});
|
||||||
|
@ -2280,14 +2296,23 @@ TEST_F(DeclarableOpsTests3, svd_test6) {
|
||||||
auto *u = results->at(1);
|
auto *u = results->at(1);
|
||||||
auto *v = results->at(2);
|
auto *v = results->at(2);
|
||||||
|
|
||||||
ASSERT_TRUE(expS.equalsTo(s));
|
|
||||||
ASSERT_TRUE(expU.equalsTo(u));
|
|
||||||
ASSERT_TRUE(expV.equalsTo(v));
|
|
||||||
|
|
||||||
ASSERT_TRUE(expS.isSameShape(s));
|
ASSERT_TRUE(expS.isSameShape(s));
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
ASSERT_TRUE(expV.isSameShape(v));
|
ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expS.equalsTo(s));
|
||||||
|
|
||||||
|
if(nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
||||||
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2451,13 +2476,22 @@ TEST_F(DeclarableOpsTests3, svd_test7) {
|
||||||
// auto *u = results->at(1);
|
// auto *u = results->at(1);
|
||||||
// auto *v = results->at(2);
|
// auto *v = results->at(2);
|
||||||
|
|
||||||
// ASSERT_TRUE(expS.isSameShape(s));
|
// ASSERT_TRUE(expS.isSameShape(s));
|
||||||
// ASSERT_TRUE(expU.isSameShape(u));
|
// ASSERT_TRUE(expU.isSameShape(u));
|
||||||
// ASSERT_TRUE(expV.isSameShape(v));
|
// ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
// ASSERT_TRUE(expS.equalsTo(s));
|
// ASSERT_TRUE(expS.equalsTo(s));
|
||||||
// ASSERT_TRUE(expU.equalsTo(u));
|
|
||||||
// ASSERT_TRUE(expV.equalsTo(v));
|
// if(nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
// ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
// ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
// }
|
||||||
|
// else {
|
||||||
|
// for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
|
// ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
||||||
|
// for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
|
// ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
||||||
|
// }
|
||||||
|
|
||||||
// delete results;
|
// delete results;
|
||||||
// }
|
// }
|
||||||
|
@ -2555,14 +2589,23 @@ TEST_F(DeclarableOpsTests3, svd_test9) {
|
||||||
auto *u = results->at(1);
|
auto *u = results->at(1);
|
||||||
auto *v = results->at(2);
|
auto *v = results->at(2);
|
||||||
|
|
||||||
ASSERT_TRUE(expS.equalsTo(s));
|
|
||||||
ASSERT_TRUE(expU.equalsTo(u));
|
|
||||||
ASSERT_TRUE(expV.equalsTo(v));
|
|
||||||
|
|
||||||
ASSERT_TRUE(expS.isSameShape(s));
|
ASSERT_TRUE(expS.isSameShape(s));
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
ASSERT_TRUE(expV.isSameShape(v));
|
ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
|
ASSERT_TRUE(expS.equalsTo(s));
|
||||||
|
|
||||||
|
if(nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
||||||
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2659,9 +2702,42 @@ TEST_F(DeclarableOpsTests3, svd_test10) {
|
||||||
auto *u = results->at(1);
|
auto *u = results->at(1);
|
||||||
auto *v = results->at(2);
|
auto *v = results->at(2);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expS.isSameShape(s));
|
||||||
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
|
ASSERT_TRUE(expV.isSameShape(v));
|
||||||
|
|
||||||
ASSERT_TRUE(expS.equalsTo(s));
|
ASSERT_TRUE(expS.equalsTo(s));
|
||||||
ASSERT_TRUE(expU.equalsTo(u));
|
|
||||||
ASSERT_TRUE(expV.equalsTo(v));
|
if(nd4j::Environment::getInstance()->isCPU()) {
|
||||||
|
ASSERT_TRUE(expU.equalsTo(u));
|
||||||
|
ASSERT_TRUE(expV.equalsTo(v));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
||||||
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests3, svd_test11) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {3,3}, {1.,2.,3.,4.,5.,6.,7.,8.,9.});
|
||||||
|
auto expS = NDArrayFactory::create<double>('c', {3});
|
||||||
|
auto expU = NDArrayFactory::create<double>('c', {3,3});
|
||||||
|
auto expV = NDArrayFactory::create<double>('c', {3,3});
|
||||||
|
|
||||||
|
nd4j::ops::svd op;
|
||||||
|
auto results = op.execute({&x}, {}, {0, 1, 16});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto s = results->at(0);
|
||||||
|
auto u = results->at(1);
|
||||||
|
auto v = results->at(2);
|
||||||
|
|
||||||
ASSERT_TRUE(expS.isSameShape(s));
|
ASSERT_TRUE(expS.isSameShape(s));
|
||||||
ASSERT_TRUE(expU.isSameShape(u));
|
ASSERT_TRUE(expU.isSameShape(u));
|
||||||
|
@ -2679,4 +2755,3 @@ TEST_F(DeclarableOpsTests3, svd_test10) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1933,7 +1933,7 @@ TEST_F(DeclarableOpsTests4, triu_bp_test1) {
|
||||||
auto gradO = NDArrayFactory::create<double>('c', {2, 3, 2});
|
auto gradO = NDArrayFactory::create<double>('c', {2, 3, 2});
|
||||||
gradO = 0.5;
|
gradO = 0.5;
|
||||||
|
|
||||||
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {0.,0.5,0.,0. ,0.,0. ,0.,0.5,0.,0. ,0.,0.});
|
auto expected = NDArrayFactory::create<double>('c', {2, 3, 2}, {0.,0.5,0.,0. ,0.,0. ,0.,0.5,0.,0. ,0.,0.});
|
||||||
|
|
||||||
nd4j::ops::triu_bp op;
|
nd4j::ops::triu_bp op;
|
||||||
auto results = op.execute({&input, &gradO}, {}, {1});
|
auto results = op.execute({&input, &gradO}, {}, {1});
|
||||||
|
|
|
@ -177,9 +177,9 @@ TEST_F(DeclarableOpsTests5, Test_TTS_bp_1) {
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
z->printShapeInfo("RES shape");
|
// z->printShapeInfo("RES shape");
|
||||||
x.printShapeInfo("EXP shape");
|
// x.printShapeInfo("EXP shape");
|
||||||
z->printIndexedBuffer("RES output");
|
// z->printIndexedBuffer("RES output");
|
||||||
ASSERT_TRUE(x.isSameShape(z));
|
ASSERT_TRUE(x.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
@ -1335,11 +1335,11 @@ TEST_F(DeclarableOpsTests5, trace_test1) {
|
||||||
auto results = op.execute({&input}, {}, {});
|
auto results = op.execute({&input}, {}, {});
|
||||||
auto output = results->at(0);
|
auto output = results->at(0);
|
||||||
double traceM = matrix.getTrace();
|
double traceM = matrix.getTrace();
|
||||||
nd4j_printf("Trace for matrix is %f\n", traceM);
|
// nd4j_printf("Trace for matrix is %f\n", traceM);
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
ASSERT_TRUE(exp.isSameShape(output));
|
ASSERT_TRUE(exp.isSameShape(output));
|
||||||
exp.printIndexedBuffer("EXP TRACE");
|
// exp.printIndexedBuffer("EXP TRACE");
|
||||||
output->printIndexedBuffer("OUT TRACE");
|
// output->printIndexedBuffer("OUT TRACE");
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
|
@ -1209,6 +1209,51 @@ TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
|
||||||
|
|
||||||
delete res;
|
delete res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {1}, {4});
|
||||||
|
|
||||||
|
// ------------------------------------
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<int>('c', {1}, {4});
|
||||||
|
|
||||||
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
|
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
|
// res->at(0)->printIndexedBuffer("Output SGO 8");
|
||||||
|
// exp.printIndexedBuffer("Expect");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<int>('c', {2}, {2,2});
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
|
// ------------------------------------
|
||||||
|
|
||||||
|
auto exp = NDArrayFactory::create<int>('c', {2}, {2,2});
|
||||||
|
|
||||||
|
nd4j::ops::broadcast_dynamic_shape op;
|
||||||
|
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, res->status());
|
||||||
|
res->at(0)->printIndexedBuffer("Output SGO 9");
|
||||||
|
exp.printIndexedBuffer("Expect9");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(res->at(0)));
|
||||||
|
|
||||||
|
delete res;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
|
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
|
||||||
|
|
||||||
|
@ -1496,10 +1541,33 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests6, LogDet_2) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {1}, { 3.5835189});
|
||||||
|
|
||||||
|
//x.printIndexedBuffer("Input");
|
||||||
|
nd4j::ops::logdet op;
|
||||||
|
auto result = op.execute({&x}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
// z->printIndexedBuffer("Output ");
|
||||||
|
// z->printShapeInfo("Shape");
|
||||||
|
//exp.printIndexedBuffer("Expected ");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 5, 5}, {
|
auto x = NDArrayFactory::create<float>('c', {2, 5, 5}, {
|
||||||
2., 4., 60., 8., 10.,
|
2., 4., 60., 8., 10.,
|
||||||
0., 1., 2., 3., 4.,
|
0., 1., 2., 3., 4.,
|
||||||
0., 0., 2., 4., 6.,
|
0., 0., 2., 4., 6.,
|
||||||
|
@ -1513,7 +1581,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
||||||
5., 4., 3., 2., 1.,
|
5., 4., 3., 2., 1.,
|
||||||
});
|
});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<double>('c', {2, 5, 5}, {
|
auto exp = NDArrayFactory::create<float>('c', {2, 5, 5}, {
|
||||||
0.5, -2.0, -13.0, 54.0, -6.75,
|
0.5, -2.0, -13.0, 54.0, -6.75,
|
||||||
0.0, 1.0, -1.0, 1.0, 0.0,
|
0.0, 1.0, -1.0, 1.0, 0.0,
|
||||||
0, 0, 0.5, -2.0, 0.25,
|
0, 0, 0.5, -2.0, 0.25,
|
||||||
|
@ -1528,7 +1596,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
||||||
});
|
});
|
||||||
|
|
||||||
nd4j::ops::matrix_inverse op;
|
nd4j::ops::matrix_inverse op;
|
||||||
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
|
|
@ -575,7 +575,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_1) {
|
||||||
z->printShapeInfo("Stitch Shape");
|
z->printShapeInfo("Stitch Shape");
|
||||||
ASSERT_TRUE(z->isSameShape(exp));
|
ASSERT_TRUE(z->isSameShape(exp));
|
||||||
ASSERT_TRUE(z->equalsTo(exp));
|
ASSERT_TRUE(z->equalsTo(exp));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6242,23 +6242,18 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_3) {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests7, Test_CumSum_BP_1) {
|
TEST_F(DeclarableOpsTests7, cumsum_bp_1) {
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
||||||
// auto y = NDArrayFactory::create<double>('c', {3, 4});
|
|
||||||
// auto z; // = NDArrayFactory::create<double>('c', {4});
|
|
||||||
auto eps = NDArrayFactory::create<double>('c', {3, 4});
|
auto eps = NDArrayFactory::create<double>('c', {3, 4});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3, 4}, {12.f, 11.f, 10.f, 9.f, 8.f, 7.f,
|
auto exp = NDArrayFactory::create<double>('c', {3, 4}, {12.f, 11.f, 10.f, 9.f, 8.f, 7.f,
|
||||||
6.f, 5.f, 4.f, 3.f, 2.f, 1.f});
|
6.f, 5.f, 4.f, 3.f, 2.f, 1.f});
|
||||||
x.linspace(1);
|
x.linspace(1);
|
||||||
eps.assign(1.f);
|
eps.assign(1.f);
|
||||||
|
|
||||||
// z = x.applyReduce3<simdOps::Dot<float>>(&y, {0}, nullptr);
|
|
||||||
nd4j::ops::cumsum_bp op;
|
nd4j::ops::cumsum_bp op;
|
||||||
auto result = op.execute({&x, &eps}, {}, {0,0});
|
auto result = op.execute({&x, &eps}, {}, {0,0});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
// output->printIndexedBuffer("Result is");
|
|
||||||
// output->printShapeInfo("Result shape is");
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -6266,27 +6261,21 @@ TEST_F(DeclarableOpsTests7, Test_CumSum_BP_1) {
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
// delete z;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests7, Test_CumSum_BP_2) {
|
TEST_F(DeclarableOpsTests7, cumsum_bp_2) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
||||||
// auto y = NDArrayFactory::create<double>('c', {3, 4});
|
|
||||||
// auto z; // = NDArrayFactory::create<double>('c', {4});
|
|
||||||
auto eps = NDArrayFactory::create<double>('c', {3, 4});
|
auto eps = NDArrayFactory::create<double>('c', {3, 4});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3, 4}, { 11.f, 10.f, 9.f, 8.f, 7.f, 6.f,
|
auto exp = NDArrayFactory::create<double>('c', {3, 4}, { 11.f, 10.f, 9.f, 8.f, 7.f, 6.f,
|
||||||
5.f, 4.f, 3.f, 2.f, 1.f, 0.f});
|
5.f, 4.f, 3.f, 2.f, 1.f, 0.f});
|
||||||
x.linspace(1);
|
x.linspace(1);
|
||||||
// exp.linspace(1);
|
|
||||||
eps.assign(1.f);
|
eps.assign(1.f);
|
||||||
|
|
||||||
// z = x.applyReduce3<simdOps::Dot<float>>(&y, {0}, nullptr);
|
|
||||||
nd4j::ops::cumsum_bp op;
|
nd4j::ops::cumsum_bp op;
|
||||||
auto result = op.execute({&x, &eps}, {}, {1,0});
|
auto result = op.execute({&x, &eps}, {}, {1,0});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
// output->printIndexedBuffer("Result is");
|
|
||||||
// output->printShapeInfo("Result shape is");
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
@ -6294,14 +6283,11 @@ TEST_F(DeclarableOpsTests7, Test_CumSum_BP_2) {
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests7, Test_Reduce_CumSum_BP_3) {
|
TEST_F(DeclarableOpsTests7, cumsum_bp_3) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
auto x = NDArrayFactory::create<double>('c', {3, 4});
|
||||||
// auto y = NDArrayFactory::create<double>('c', {3, 4});
|
|
||||||
// auto z; // = NDArrayFactory::create<double>('c', {4});
|
|
||||||
auto eps = NDArrayFactory::create<double>('c', {3, 4});
|
auto eps = NDArrayFactory::create<double>('c', {3, 4});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3, 4});
|
auto exp = NDArrayFactory::create<double>('c', {3, 4});
|
||||||
|
|
||||||
|
@ -6309,16 +6295,11 @@ TEST_F(DeclarableOpsTests7, Test_Reduce_CumSum_BP_3) {
|
||||||
exp.linspace(0);
|
exp.linspace(0);
|
||||||
eps.assign(1.f);
|
eps.assign(1.f);
|
||||||
|
|
||||||
// z = x.applyReduce3<simdOps::Dot<float>>(&y, {0}, nullptr);
|
|
||||||
nd4j::ops::cumsum_bp op;
|
nd4j::ops::cumsum_bp op;
|
||||||
auto result = op.execute({&x, &eps}, {}, {1,1});
|
auto result = op.execute({&x, &eps}, {}, {1,1});
|
||||||
auto output = result->at(0);
|
auto output = result->at(0);
|
||||||
// output->printIndexedBuffer("Result is");
|
|
||||||
// output->printShapeInfo("Result shape is");
|
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
// ASSERT_TRUE(exp.isSameShape(output));
|
|
||||||
ASSERT_TRUE(exp.equalsTo(output));
|
ASSERT_TRUE(exp.equalsTo(output));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
|
|
@ -236,45 +236,6 @@ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) {
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1.,2.,3.,4.,5.,6.});
|
|
||||||
auto gradO = NDArrayFactory::create<double>('c', {4, 9});
|
|
||||||
auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44});
|
|
||||||
|
|
||||||
gradO.linspace(0.01, 0.01);
|
|
||||||
|
|
||||||
nd4j::ops::tile_bp op;
|
|
||||||
auto results = op.execute({&input, &gradO}, {}, {2, 3});
|
|
||||||
auto gradI = results->at(0);
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests9, tile_bp_test2) {
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1.,2.,3.,4.,5.,6.});
|
|
||||||
auto gradO = NDArrayFactory::create<double>('c', {2, 9});
|
|
||||||
auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45});
|
|
||||||
|
|
||||||
gradO.linspace(0.01, 0.01);
|
|
||||||
|
|
||||||
nd4j::ops::tile_bp op;
|
|
||||||
auto results = op.execute({&input, &gradO}, {}, {1, 3});
|
|
||||||
auto gradI = results->at(0);
|
|
||||||
ASSERT_EQ(Status::OK(), results->status());
|
|
||||||
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
|
||||||
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
|
||||||
|
|
||||||
delete results;
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, concat_test1) {
|
TEST_F(DeclarableOpsTests9, concat_test1) {
|
||||||
|
|
||||||
|
@ -626,6 +587,45 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1.,2.,3.,4.,5.,6.});
|
||||||
|
auto gradO = NDArrayFactory::create<double>('c', {4, 9});
|
||||||
|
auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44});
|
||||||
|
|
||||||
|
gradO.linspace(0.01, 0.01);
|
||||||
|
|
||||||
|
nd4j::ops::tile_bp op;
|
||||||
|
auto results = op.execute({&input, &gradO}, {}, {2, 3});
|
||||||
|
auto gradI = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, tile_bp_test2) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1.,2.,3.,4.,5.,6.});
|
||||||
|
auto gradO = NDArrayFactory::create<double>('c', {2, 9});
|
||||||
|
auto gradIExp = NDArrayFactory::create<double>('c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45});
|
||||||
|
|
||||||
|
gradO.linspace(0.01, 0.01);
|
||||||
|
|
||||||
|
nd4j::ops::tile_bp op;
|
||||||
|
auto results = op.execute({&input, &gradO}, {}, {1, 3});
|
||||||
|
auto gradI = results->at(0);
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
ASSERT_TRUE(gradIExp.isSameShape(gradI));
|
||||||
|
ASSERT_TRUE(gradIExp.equalsTo(gradI));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, tile_bp_test3) {
|
TEST_F(DeclarableOpsTests9, tile_bp_test3) {
|
||||||
|
|
||||||
|
@ -2623,21 +2623,52 @@ TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) {
|
||||||
auto dLdzX = NDArrayFactory::create<double>('c', {2, 4});
|
auto dLdzX = NDArrayFactory::create<double>('c', {2, 4});
|
||||||
auto dLdzY = NDArrayFactory::create<double>('c', {2, 4});
|
auto dLdzY = NDArrayFactory::create<double>('c', {2, 4});
|
||||||
auto dLdzZ = NDArrayFactory::create<double>('c', {2, 4});
|
auto dLdzZ = NDArrayFactory::create<double>('c', {2, 4});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {2,3,4}, {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3});
|
||||||
x.linspace(1);
|
x.linspace(1);
|
||||||
dLdzX.linspace(1);
|
// dLdzX.linspace(1);
|
||||||
dLdzY.linspace(2);
|
// dLdzY.linspace(2);
|
||||||
dLdzZ.linspace(3);
|
// dLdzZ.linspace(3);
|
||||||
|
dLdzX.assign(1);
|
||||||
|
dLdzY.assign(2);
|
||||||
|
dLdzZ.assign(3);
|
||||||
|
|
||||||
nd4j::ops::dynamic_partition op1;
|
nd4j::ops::dynamic_partition op1;
|
||||||
auto res1 = op1.execute({&x, &y}, {}, {3});
|
auto res1 = op1.execute({&x, &y}, {}, {3});
|
||||||
|
|
||||||
nd4j::ops::dynamic_partition_bp op2;
|
nd4j::ops::dynamic_partition_bp op2;
|
||||||
auto res2 = op2.execute({&x, &y, res1->at(0), res1->at(1), res1->at(2)}, {}, {3});
|
auto res2 = op2.execute({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});
|
||||||
ASSERT_TRUE(res2->status() == ND4J_STATUS_OK);
|
ASSERT_TRUE(res2->status() == ND4J_STATUS_OK);
|
||||||
ASSERT_TRUE(res2->size() == 2);
|
ASSERT_TRUE(res2->size() == 2);
|
||||||
|
// printf("How many: %ul\n", res2->size());
|
||||||
|
// res2->at(0)->printBuffer("Ouputput0");
|
||||||
|
// res2->at(1)->printBuffer("Ouputput1");
|
||||||
|
ASSERT_TRUE(res2->at(0)->equalsTo(exp));
|
||||||
delete res1;
|
delete res1;
|
||||||
delete res2;
|
delete res2;
|
||||||
}
|
}
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
//TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_2) {
|
||||||
|
//
|
||||||
|
// auto x = NDArrayFactory::create<double>('c', {2, 3, 4});
|
||||||
|
// auto y = NDArrayFactory::create<int>('c', {2, 3}, {0, 1, 2, 1, 0, 2});
|
||||||
|
// auto dLdzX = NDArrayFactory::create<double>('c', {2, 4});
|
||||||
|
// auto dLdzY = NDArrayFactory::create<double>('c', {2, 4});
|
||||||
|
// auto dLdzZ = NDArrayFactory::create<double>('c', {2, 4});
|
||||||
|
// x.linspace(1);
|
||||||
|
// dLdzX.linspace(1);
|
||||||
|
// dLdzY.linspace(1);
|
||||||
|
// dLdzZ.linspace(1);
|
||||||
|
//
|
||||||
|
// const OpArgsHolder argsHolderFF({&x, &y}, {}, {3});
|
||||||
|
// const OpArgsHolder argsHolderBP({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3});
|
||||||
|
//
|
||||||
|
// nd4j::ops::dynamic_partition opFF;
|
||||||
|
// nd4j::ops::dynamic_partition_bp opBP;
|
||||||
|
//
|
||||||
|
// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||||
|
//
|
||||||
|
// ASSERT_TRUE(isGradCorrect);
|
||||||
|
//}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
|
TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) {
|
||||||
|
@ -2914,7 +2945,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.execute({&x}, {}, {});
|
||||||
ASSERT_EQ(result->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(result->status(), ND4J_STATUS_OK);
|
||||||
auto res = result->at(0);
|
auto res = result->at(0);
|
||||||
//res->printIndexedBuffer("Output for Cholesky");
|
// res->printIndexedBuffer("Output for Cholesky1");
|
||||||
ASSERT_TRUE(exp.equalsTo(res));
|
ASSERT_TRUE(exp.equalsTo(res));
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
@ -2922,6 +2953,22 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_1) {
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
|
TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<double>('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6});
|
||||||
|
NDArray exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.});
|
||||||
|
|
||||||
|
nd4j::ops::cholesky op;
|
||||||
|
|
||||||
|
auto result = op.execute({&x}, {}, {});
|
||||||
|
ASSERT_EQ(result->status(), ND4J_STATUS_OK);
|
||||||
|
auto res = result->at(0);
|
||||||
|
// res->printIndexedBuffer("Output for Cholesky 2");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(res));
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
|
||||||
|
|
||||||
NDArray x = NDArrayFactory::create<float>('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6});
|
NDArray x = NDArrayFactory::create<float>('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6});
|
||||||
NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.});
|
NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.});
|
||||||
|
|
||||||
|
@ -2930,7 +2977,7 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
|
||||||
auto result = op.execute({&x}, {}, {});
|
auto result = op.execute({&x}, {}, {});
|
||||||
ASSERT_EQ(result->status(), ND4J_STATUS_OK);
|
ASSERT_EQ(result->status(), ND4J_STATUS_OK);
|
||||||
auto res = result->at(0);
|
auto res = result->at(0);
|
||||||
//res->printIndexedBuffer("Output for Cholesky");
|
// res->printIndexedBuffer("Output for Cholesky 3");
|
||||||
ASSERT_TRUE(exp.equalsTo(res));
|
ASSERT_TRUE(exp.equalsTo(res));
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1215,10 +1215,14 @@ public class Nd4j {
|
||||||
|
|
||||||
protected static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
|
protected static Indexer getIndexerByType(Pointer pointer, DataType dataType) {
|
||||||
switch (dataType) {
|
switch (dataType) {
|
||||||
|
case UINT64:
|
||||||
case LONG:
|
case LONG:
|
||||||
return LongIndexer.create((LongPointer) pointer);
|
return LongIndexer.create((LongPointer) pointer);
|
||||||
|
case UINT32:
|
||||||
case INT:
|
case INT:
|
||||||
return IntIndexer.create((IntPointer) pointer);
|
return IntIndexer.create((IntPointer) pointer);
|
||||||
|
case UINT16:
|
||||||
|
return UShortIndexer.create((ShortPointer) pointer);
|
||||||
case SHORT:
|
case SHORT:
|
||||||
return ShortIndexer.create((ShortPointer) pointer);
|
return ShortIndexer.create((ShortPointer) pointer);
|
||||||
case BYTE:
|
case BYTE:
|
||||||
|
@ -1229,6 +1233,8 @@ public class Nd4j {
|
||||||
return BooleanIndexer.create((BooleanPointer) pointer);
|
return BooleanIndexer.create((BooleanPointer) pointer);
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
return FloatIndexer.create((FloatPointer) pointer);
|
return FloatIndexer.create((FloatPointer) pointer);
|
||||||
|
case BFLOAT16:
|
||||||
|
return Bfloat16Indexer.create((ShortPointer) pointer);
|
||||||
case HALF:
|
case HALF:
|
||||||
return HalfIndexer.create((ShortPointer) pointer);
|
return HalfIndexer.create((ShortPointer) pointer);
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
|
@ -1297,12 +1303,15 @@ public class Nd4j {
|
||||||
public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) {
|
public static DataBuffer createBuffer(@NonNull Pointer pointer, @NonNull Pointer devicePointer, long length, @NonNull DataType dataType) {
|
||||||
Pointer nPointer = null;
|
Pointer nPointer = null;
|
||||||
switch (dataType) {
|
switch (dataType) {
|
||||||
|
case UINT64:
|
||||||
case LONG:
|
case LONG:
|
||||||
nPointer = new LongPointer(pointer);
|
nPointer = new LongPointer(pointer);
|
||||||
break;
|
break;
|
||||||
|
case UINT32:
|
||||||
case INT:
|
case INT:
|
||||||
nPointer = new IntPointer(pointer);
|
nPointer = new IntPointer(pointer);
|
||||||
break;
|
break;
|
||||||
|
case UINT16:
|
||||||
case SHORT:
|
case SHORT:
|
||||||
nPointer = new ShortPointer(pointer);
|
nPointer = new ShortPointer(pointer);
|
||||||
break;
|
break;
|
||||||
|
@ -1315,12 +1324,13 @@ public class Nd4j {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
nPointer = new BooleanPointer(pointer);
|
nPointer = new BooleanPointer(pointer);
|
||||||
break;
|
break;
|
||||||
case FLOAT:
|
case BFLOAT16:
|
||||||
nPointer = new FloatPointer(pointer);
|
|
||||||
break;
|
|
||||||
case HALF:
|
case HALF:
|
||||||
nPointer = new ShortPointer(pointer);
|
nPointer = new ShortPointer(pointer);
|
||||||
break;
|
break;
|
||||||
|
case FLOAT:
|
||||||
|
nPointer = new FloatPointer(pointer);
|
||||||
|
break;
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
nPointer = new DoublePointer(pointer);
|
nPointer = new DoublePointer(pointer);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -93,18 +93,7 @@ public class BasicContextPool implements ContextPool {
|
||||||
try {
|
try {
|
||||||
// this is lockable thing, but since it locks once per thread initialization, performance impact won't be big
|
// this is lockable thing, but since it locks once per thread initialization, performance impact won't be big
|
||||||
lock.acquire();
|
lock.acquire();
|
||||||
// we create 1 CUcontext per device, which will be shared for all threads/streams on this device
|
|
||||||
/*
|
|
||||||
if (!cuPool.containsKey(deviceId)) {
|
|
||||||
CUcontext cuContext = createNewContext(deviceId);
|
|
||||||
cuPool.put(deviceId, cuContext);
|
|
||||||
}
|
|
||||||
|
|
||||||
int result = JCudaDriver.cuCtxSetCurrent(cuPool.get(deviceId));
|
|
||||||
if (result != CUresult.CUDA_SUCCESS) {
|
|
||||||
throw new RuntimeException("Failed to set context on assigner");
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
if (!contextsForDevices.containsKey(deviceId)) {
|
if (!contextsForDevices.containsKey(deviceId)) {
|
||||||
contextsForDevices.put(deviceId, new ConcurrentHashMap<Integer, CudaContext>());
|
contextsForDevices.put(deviceId, new ConcurrentHashMap<Integer, CudaContext>());
|
||||||
}
|
}
|
||||||
|
@ -120,11 +109,11 @@ public class BasicContextPool implements ContextPool {
|
||||||
// if we have no contexts created - it's just awesome time to attach cuBLAS handle here
|
// if we have no contexts created - it's just awesome time to attach cuBLAS handle here
|
||||||
log.debug("Creating new cuBLAS handle for device [{}]...", deviceId);
|
log.debug("Creating new cuBLAS handle for device [{}]...", deviceId);
|
||||||
|
|
||||||
cudaStream_t cublasStream = createNewStream(deviceId).getOldStream();
|
//cudaStream_t cublasStream = createNewStream(deviceId).getOldStream();
|
||||||
|
|
||||||
cublasHandle_t handle = createNewCublasHandle(cublasStream);
|
cublasHandle_t handle = createNewCublasHandle(context.getOldStream());
|
||||||
context.setHandle(handle);
|
context.setHandle(handle);
|
||||||
context.setCublasStream(cublasStream);
|
//context.setCublasStream(cublasStream);
|
||||||
|
|
||||||
cublasPool.put(deviceId, handle);
|
cublasPool.put(deviceId, handle);
|
||||||
|
|
||||||
|
|
|
@ -62,11 +62,11 @@ public class PackedContextPool extends BasicContextPool implements ContextPool {
|
||||||
// if we have no contexts created - it's just awesome time to attach cuBLAS handle here
|
// if we have no contexts created - it's just awesome time to attach cuBLAS handle here
|
||||||
log.debug("Creating new cuBLAS handle for device [{}]", deviceId);
|
log.debug("Creating new cuBLAS handle for device [{}]", deviceId);
|
||||||
|
|
||||||
cudaStream_t cublasStream = createNewStream(deviceId).getOldStream();
|
//cudaStream_t cublasStream = createNewStream(deviceId).getOldStream();
|
||||||
|
|
||||||
cublasHandle_t handle = createNewCublasHandle(cublasStream);
|
cublasHandle_t handle = createNewCublasHandle(context.getOldStream());
|
||||||
context.setHandle(handle);
|
context.setHandle(handle);
|
||||||
context.setCublasStream(cublasStream);
|
//context.setCublasStream(cublasStream);
|
||||||
|
|
||||||
cublasPool.put(deviceId, handle);
|
cublasPool.put(deviceId, handle);
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.compression.CompressedDataBuffer;
|
import org.nd4j.linalg.compression.CompressedDataBuffer;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
|
||||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
import org.nd4j.linalg.jcublas.ops.executioner.CudaGridExecutioner;
|
import org.nd4j.linalg.jcublas.ops.executioner.CudaGridExecutioner;
|
||||||
import org.nd4j.linalg.memory.BasicMemoryManager;
|
import org.nd4j.linalg.memory.BasicMemoryManager;
|
||||||
|
@ -151,6 +152,14 @@ public class CudaMemoryManager extends BasicMemoryManager {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void allocateHostPointers(DataBuffer... dataBuffers) {
|
||||||
|
for (val v:dataBuffers) {
|
||||||
|
if (v != null && v instanceof BaseCudaDataBuffer) {
|
||||||
|
((BaseCudaDataBuffer) v).lazyAllocateHostPointer();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method provides basic memcpy functionality with respect to target environment
|
* This method provides basic memcpy functionality with respect to target environment
|
||||||
*
|
*
|
||||||
|
@ -161,9 +170,13 @@ public class CudaMemoryManager extends BasicMemoryManager {
|
||||||
public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
|
public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
|
||||||
CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
||||||
|
|
||||||
|
|
||||||
if (dstBuffer instanceof CompressedDataBuffer && !(srcBuffer instanceof CompressedDataBuffer)) {
|
if (dstBuffer instanceof CompressedDataBuffer && !(srcBuffer instanceof CompressedDataBuffer)) {
|
||||||
// destination is compressed, source isn't
|
// destination is compressed, source isn't
|
||||||
AllocationPoint srcPoint = AtomicAllocator.getInstance().getAllocationPoint(srcBuffer);
|
AllocationPoint srcPoint = AtomicAllocator.getInstance().getAllocationPoint(srcBuffer);
|
||||||
|
|
||||||
|
allocateHostPointers(dstBuffer, srcBuffer);
|
||||||
|
|
||||||
long size = srcBuffer.getElementSize() * srcBuffer.length();
|
long size = srcBuffer.getElementSize() * srcBuffer.length();
|
||||||
if (!srcPoint.isActualOnHostSide()) {
|
if (!srcPoint.isActualOnHostSide()) {
|
||||||
// copying device -> host
|
// copying device -> host
|
||||||
|
@ -177,12 +190,14 @@ public class CudaMemoryManager extends BasicMemoryManager {
|
||||||
|
|
||||||
} // else {
|
} // else {
|
||||||
// copying host -> host
|
// copying host -> host
|
||||||
Pointer src = AtomicAllocator.getInstance().getHostPointer(srcBuffer);
|
val src = AtomicAllocator.getInstance().getHostPointer(srcBuffer);
|
||||||
|
|
||||||
Pointer.memcpy(dstBuffer.addressPointer(), src, size);
|
Pointer.memcpy(dstBuffer.addressPointer(), src, size);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
} else if (!(dstBuffer instanceof CompressedDataBuffer) && srcBuffer instanceof CompressedDataBuffer) {
|
} else if (!(dstBuffer instanceof CompressedDataBuffer) && srcBuffer instanceof CompressedDataBuffer) {
|
||||||
|
allocateHostPointers(dstBuffer, srcBuffer);
|
||||||
|
|
||||||
// destination is NOT compressed, source is compressed
|
// destination is NOT compressed, source is compressed
|
||||||
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(dstBuffer);
|
AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(dstBuffer);
|
||||||
long size = srcBuffer.getElementSize() * srcBuffer.length();
|
long size = srcBuffer.getElementSize() * srcBuffer.length();
|
||||||
|
@ -193,6 +208,7 @@ public class CudaMemoryManager extends BasicMemoryManager {
|
||||||
} else if (dstBuffer instanceof CompressedDataBuffer && srcBuffer instanceof CompressedDataBuffer) {
|
} else if (dstBuffer instanceof CompressedDataBuffer && srcBuffer instanceof CompressedDataBuffer) {
|
||||||
// both buffers are compressed, just fire memcpy
|
// both buffers are compressed, just fire memcpy
|
||||||
|
|
||||||
|
allocateHostPointers(dstBuffer, srcBuffer);
|
||||||
|
|
||||||
Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(),
|
Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(),
|
||||||
srcBuffer.length() * srcBuffer.getElementSize());
|
srcBuffer.length() * srcBuffer.getElementSize());
|
||||||
|
|
|
@ -124,65 +124,6 @@ public class CublasPointer implements AutoCloseable {
|
||||||
this.cudaContext = context;
|
this.cudaContext = context;
|
||||||
this.devicePointer = AtomicAllocator.getInstance().getPointer(array, context);
|
this.devicePointer = AtomicAllocator.getInstance().getPointer(array, context);
|
||||||
|
|
||||||
/*
|
|
||||||
if(array instanceof IComplexNDArray) {
|
|
||||||
if(array.length() * 2 < array.data().length() && !array.isVector()) {
|
|
||||||
array = Shape.toOffsetZero(array);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
buffer = (JCudaBuffer) array.data();
|
|
||||||
|
|
||||||
//the opName of this thread for knowing whether to copy data or not
|
|
||||||
//String opName = Thread.currentThread().getName();
|
|
||||||
this.arr = array;
|
|
||||||
if(array.elementWiseStride() < 0) {
|
|
||||||
this.arr = array.dup();
|
|
||||||
buffer = (JCudaBuffer) this.arr.data();
|
|
||||||
if(this.arr.elementWiseStride() < 0)
|
|
||||||
throw new IllegalStateException("Unable to iterate over buffer");
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
//int compLength = arr instanceof IComplexNDArray ? arr.length() * 2 : arr.length();
|
|
||||||
////int stride = arr instanceof IComplexNDArray ? BlasBufferUtil.getBlasStride(arr) / 2 : BlasBufferUtil.getBlasStride(arr);
|
|
||||||
//no striding for upload if we are using the whole buffer
|
|
||||||
// System.out.println("Allocation offset: ["+array.offset()+"], length: ["+compLength+"], stride: ["+ stride+"]");
|
|
||||||
|
|
||||||
/*
|
|
||||||
buffer.getPointer(
|
|
||||||
this.arr,
|
|
||||||
stride
|
|
||||||
,this.arr.offset()
|
|
||||||
,compLength);
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Neat edge case here.
|
|
||||||
*
|
|
||||||
* The striding will overshoot the original array
|
|
||||||
* when the offset is zero (the case being when offset is zero
|
|
||||||
* sayon a getRow(0) operation.
|
|
||||||
*
|
|
||||||
* We need to allocate the data differently here
|
|
||||||
* due to how the striding works out.
|
|
||||||
*/
|
|
||||||
// Copy the data to the device iff the whole buffer hasn't been copied
|
|
||||||
/*
|
|
||||||
|
|
||||||
//Data is already copied into CUDA buffer during allocation at getPointer
|
|
||||||
|
|
||||||
if(!buffer.copied(opName)) {
|
|
||||||
ContextHolder.getInstance().getMemoryStrategy().setData(buffer,0,1,buffer.length());
|
|
||||||
//mark the buffer copied
|
|
||||||
buffer.setCopied(opName);
|
|
||||||
|
|
||||||
}*/
|
|
||||||
|
|
||||||
/*
|
|
||||||
DevicePointerInfo info = buffer.getPointersToContexts().get(Thread.currentThread().getName(), Triple.of(0, buffer.length(), 1));
|
|
||||||
hostPointer = info.getPointers().getHostPointer();
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -360,12 +360,16 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
|
|
||||||
for (INDArray m : matrices) {
|
for (INDArray m : matrices) {
|
||||||
|
if (m.isEmpty())
|
||||||
|
continue;
|
||||||
|
|
||||||
CudaContext context = allocator.getFlowController().prepareAction(ret, m);
|
CudaContext context = allocator.getFlowController().prepareAction(ret, m);
|
||||||
|
|
||||||
if (m.ordering() == order && ret.elementWiseStride() == m.elementWiseStride()
|
if (m.ordering() == order && ret.elementWiseStride() == m.elementWiseStride()
|
||||||
&& ret.elementWiseStride() == 1) {
|
&& ret.elementWiseStride() == 1) {
|
||||||
// do memcpy in proper direction and forget about that
|
// do memcpy in proper direction and forget about that
|
||||||
|
// FIXME: get rid of this
|
||||||
|
((BaseCudaDataBuffer) m.data()).lazyAllocateHostPointer();
|
||||||
allocator.memcpyAsync(ret.data(), new CudaPointer(allocator.getHostPointer(m).address()),
|
allocator.memcpyAsync(ret.data(), new CudaPointer(allocator.getHostPointer(m).address()),
|
||||||
AllocationUtils.getRequiredMemory(AllocationUtils.buildAllocationShape(m)),
|
AllocationUtils.getRequiredMemory(AllocationUtils.buildAllocationShape(m)),
|
||||||
linearIndex * (m.data().dataType() == DataType.DOUBLE ? 8
|
linearIndex * (m.data().dataType() == DataType.DOUBLE ? 8
|
||||||
|
@ -560,6 +564,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
|
|
||||||
for (int i = 0; i < toConcat.length; i++) {
|
for (int i = 0; i < toConcat.length; i++) {
|
||||||
|
((BaseCudaDataBuffer) toConcat[i].data()).lazyAllocateHostPointer();
|
||||||
|
|
||||||
if (toConcat[i].isCompressed())
|
if (toConcat[i].isCompressed())
|
||||||
Nd4j.getCompressor().decompressi(toConcat[i]);
|
Nd4j.getCompressor().decompressi(toConcat[i]);
|
||||||
|
|
||||||
|
@ -577,15 +583,15 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
outputShape[dimension] = sumAlongDim;
|
outputShape[dimension] = sumAlongDim;
|
||||||
|
|
||||||
val dummy = new PointerPointer(new Pointer[] {null});
|
|
||||||
|
|
||||||
val ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order());
|
val ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order());
|
||||||
|
|
||||||
|
((BaseCudaDataBuffer) ret.data()).lazyAllocateHostPointer();
|
||||||
|
|
||||||
nativeOps.specialConcat(dummy, dimension, toConcat.length, dataPointers, shapeInfoPointers,
|
nativeOps.specialConcat(null, dimension, toConcat.length, dataPointers, shapeInfoPointers,
|
||||||
ret.data().addressPointer(),
|
ret.data().addressPointer(),
|
||||||
(LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
|
(LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
|
||||||
new PointerPointer(new Pointer[] {null}), new PointerPointer(new Pointer[] {null}));
|
null, null);
|
||||||
|
|
||||||
|
|
||||||
AllocationPoint point = allocator.getAllocationPoint(ret);
|
AllocationPoint point = allocator.getAllocationPoint(ret);
|
||||||
|
@ -780,8 +786,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
allocator.getFlowController().registerAction(context, target, arrays);
|
allocator.getFlowController().registerAction(context, target, arrays);
|
||||||
|
|
||||||
tempX.address();
|
|
||||||
|
|
||||||
return target;
|
return target;
|
||||||
} else {
|
} else {
|
||||||
long len = target.lengthLong();
|
long len = target.lengthLong();
|
||||||
|
@ -803,9 +807,13 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (arrays[i].lengthLong() != len)
|
if (arrays[i].lengthLong() != len)
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
||||||
|
|
||||||
|
((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer();
|
||||||
|
|
||||||
dataPointers.put(i, AtomicAllocator.getInstance().getHostPointer(arrays[i]));
|
dataPointers.put(i, AtomicAllocator.getInstance().getHostPointer(arrays[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (target != null)
|
||||||
|
((BaseCudaDataBuffer) target.data()).lazyAllocateHostPointer();
|
||||||
|
|
||||||
nativeOps.accumulate(extras,
|
nativeOps.accumulate(extras,
|
||||||
dataPointers,
|
dataPointers,
|
||||||
|
@ -823,7 +831,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite();
|
AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return target;
|
return target;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -893,8 +900,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
allocator.getFlowController().registerAction(context, target, arrays);
|
allocator.getFlowController().registerAction(context, target, arrays);
|
||||||
|
|
||||||
tempX.address();
|
|
||||||
|
|
||||||
return target;
|
return target;
|
||||||
} else {
|
} else {
|
||||||
// otherwise we do averging on CPU side
|
// otherwise we do averging on CPU side
|
||||||
|
@ -918,9 +923,14 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
if (arrays[i].lengthLong() != len)
|
if (arrays[i].lengthLong() != len)
|
||||||
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
throw new ND4JIllegalStateException("All arrays should have equal length for averaging");
|
||||||
|
|
||||||
|
((BaseCudaDataBuffer) arrays[i].data()).lazyAllocateHostPointer();
|
||||||
|
|
||||||
dataPointers.put(i, AtomicAllocator.getInstance().getHostPointer(arrays[i]));
|
dataPointers.put(i, AtomicAllocator.getInstance().getHostPointer(arrays[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (target != null)
|
||||||
|
((BaseCudaDataBuffer) target.data()).lazyAllocateHostPointer();
|
||||||
|
|
||||||
nativeOps.average(extras,
|
nativeOps.average(extras,
|
||||||
dataPointers,
|
dataPointers,
|
||||||
(LongPointer) arrays[0].shapeInfoDataBuffer().addressPointer(),
|
(LongPointer) arrays[0].shapeInfoDataBuffer().addressPointer(),
|
||||||
|
@ -1114,8 +1124,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
||||||
|
|
||||||
|
|
||||||
// just to keep reference
|
// just to keep reference
|
||||||
shuffle.address();
|
//shuffle.address();
|
||||||
hostPointers.address();
|
//hostPointers.address();
|
||||||
|
|
||||||
tempX.dataType();
|
tempX.dataType();
|
||||||
tempShapes.dataType();
|
tempShapes.dataType();
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.jcublas.blas;
|
package org.nd4j.linalg.jcublas.blas;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.DoublePointer;
|
import org.bytedeco.javacpp.DoublePointer;
|
||||||
import org.bytedeco.javacpp.FloatPointer;
|
import org.bytedeco.javacpp.FloatPointer;
|
||||||
import org.bytedeco.javacpp.IntPointer;
|
import org.bytedeco.javacpp.IntPointer;
|
||||||
|
@ -35,6 +36,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.jcublas.CublasPointer;
|
import org.nd4j.linalg.jcublas.CublasPointer;
|
||||||
|
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
|
||||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
import org.nd4j.nativeblas.NativeOps;
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
@ -81,7 +83,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
|
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw new BlasException("solverSetStream failed");
|
throw new BlasException("solverSetStream failed");
|
||||||
|
|
||||||
|
@ -89,7 +91,8 @@ public class JcublasLapack extends BaseLapack {
|
||||||
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
int stat = cusolverDnSgetrf_bufferSize(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M,
|
int stat = cusolverDnSgetrf_bufferSize(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M,
|
||||||
(IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
|
(IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
|
||||||
|
@ -147,15 +150,16 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
|
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw new BlasException("solverSetStream failed");
|
throw new BlasException("solverSetStream failed");
|
||||||
|
|
||||||
// transfer the INDArray into GPU memory
|
// transfer the INDArray into GPU memory
|
||||||
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
val xAPointer = new CublasPointer(a, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
int stat = cusolverDnDgetrf_bufferSize(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
|
int stat = cusolverDnDgetrf_bufferSize(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
|
||||||
(IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
|
(IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
|
||||||
|
@ -167,7 +171,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
int worksize = worksizeBuffer.getInt(0);
|
int worksize = worksizeBuffer.getInt(0);
|
||||||
|
|
||||||
// Now allocate memory for the workspace, the permutation matrix and a return code
|
// Now allocate memory for the workspace, the permutation matrix and a return code
|
||||||
Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType());
|
val workspace = new Workspace(worksize * Nd4j.sizeOfDataType());
|
||||||
|
|
||||||
// Do the actual LU decomp
|
// Do the actual LU decomp
|
||||||
stat = cusolverDnDgetrf(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
|
stat = cusolverDnDgetrf(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M,
|
||||||
|
@ -218,7 +222,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
|
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw new IllegalStateException("solverSetStream failed");
|
throw new IllegalStateException("solverSetStream failed");
|
||||||
|
|
||||||
|
@ -227,7 +231,8 @@ public class JcublasLapack extends BaseLapack {
|
||||||
CublasPointer xTauPointer = new CublasPointer(tau, ctx);
|
CublasPointer xTauPointer = new CublasPointer(tau, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
int stat = cusolverDnSgeqrf_bufferSize(solverDn, M, N,
|
int stat = cusolverDnSgeqrf_bufferSize(solverDn, M, N,
|
||||||
(FloatPointer) xAPointer.getDevicePointer(), M,
|
(FloatPointer) xAPointer.getDevicePointer(), M,
|
||||||
|
@ -333,7 +338,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
|
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw new BlasException("solverSetStream failed");
|
throw new BlasException("solverSetStream failed");
|
||||||
|
|
||||||
|
@ -342,7 +347,8 @@ public class JcublasLapack extends BaseLapack {
|
||||||
CublasPointer xTauPointer = new CublasPointer(tau, ctx);
|
CublasPointer xTauPointer = new CublasPointer(tau, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
int stat = cusolverDnDgeqrf_bufferSize(solverDn, M, N,
|
int stat = cusolverDnDgeqrf_bufferSize(solverDn, M, N,
|
||||||
(DoublePointer) xAPointer.getDevicePointer(), M,
|
(DoublePointer) xAPointer.getDevicePointer(), M,
|
||||||
|
@ -441,7 +447,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
|
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw new BlasException("solverSetStream failed");
|
throw new BlasException("solverSetStream failed");
|
||||||
|
|
||||||
|
@ -449,7 +455,8 @@ public class JcublasLapack extends BaseLapack {
|
||||||
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
int stat = cusolverDnSpotrf_bufferSize(solverDn, uplo, N,
|
int stat = cusolverDnSpotrf_bufferSize(solverDn, uplo, N,
|
||||||
(FloatPointer) xAPointer.getDevicePointer(), N,
|
(FloatPointer) xAPointer.getDevicePointer(), N,
|
||||||
|
@ -524,7 +531,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
int result = cusolverDnSetStream(solverDn, new CUstream_st(ctx.getOldStream()));
|
int result = cusolverDnSetStream(solverDn, new CUstream_st(ctx.getCublasStream()));
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw new BlasException("solverSetStream failed");
|
throw new BlasException("solverSetStream failed");
|
||||||
|
|
||||||
|
@ -532,7 +539,8 @@ public class JcublasLapack extends BaseLapack {
|
||||||
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
int stat = cusolverDnDpotrf_bufferSize(solverDn, uplo, N,
|
int stat = cusolverDnDpotrf_bufferSize(solverDn, uplo, N,
|
||||||
(DoublePointer) xAPointer.getDevicePointer(), N,
|
(DoublePointer) xAPointer.getDevicePointer(), N,
|
||||||
|
@ -656,7 +664,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
|
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw new BlasException("solverSetStream failed");
|
throw new BlasException("solverSetStream failed");
|
||||||
|
|
||||||
|
@ -664,7 +672,8 @@ public class JcublasLapack extends BaseLapack {
|
||||||
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
int stat = cusolverDnSgesvd_bufferSize(solverDn, M, N, (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
|
int stat = cusolverDnSgesvd_bufferSize(solverDn, M, N, (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
|
||||||
);
|
);
|
||||||
|
@ -765,7 +774,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
|
int result = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw new BlasException("solverSetStream failed");
|
throw new BlasException("solverSetStream failed");
|
||||||
|
|
||||||
|
@ -773,7 +782,8 @@ public class JcublasLapack extends BaseLapack {
|
||||||
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
int stat = cusolverDnSgesvd_bufferSize(solverDn, M, N, (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
|
int stat = cusolverDnSgesvd_bufferSize(solverDn, M, N, (IntPointer) worksizeBuffer.addressPointer() // we intentionally use host pointer here
|
||||||
);
|
);
|
||||||
|
@ -851,14 +861,16 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
|
status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
if (status == 0) {
|
if (status == 0) {
|
||||||
// transfer the INDArray into GPU memory
|
// transfer the INDArray into GPU memory
|
||||||
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
||||||
CublasPointer xRPointer = new CublasPointer(R, ctx);
|
CublasPointer xRPointer = new CublasPointer(R, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
status = cusolverDnSsyevd_bufferSize(
|
status = cusolverDnSsyevd_bufferSize(
|
||||||
solverDn, jobz, uplo, M,
|
solverDn, jobz, uplo, M,
|
||||||
(FloatPointer) xAPointer.getDevicePointer(), M,
|
(FloatPointer) xAPointer.getDevicePointer(), M,
|
||||||
|
@ -869,7 +881,7 @@ public class JcublasLapack extends BaseLapack {
|
||||||
int worksize = worksizeBuffer.getInt(0);
|
int worksize = worksizeBuffer.getInt(0);
|
||||||
|
|
||||||
// allocate memory for the workspace, the non-converging row buffer and a return code
|
// allocate memory for the workspace, the non-converging row buffer and a return code
|
||||||
Pointer workspace = new Workspace(worksize * 4); //4 = float width
|
val workspace = new Workspace(worksize * 4); //4 = float width
|
||||||
|
|
||||||
INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
|
INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1),
|
||||||
Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, A.dataType()));
|
Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, A.dataType()));
|
||||||
|
@ -924,14 +936,16 @@ public class JcublasLapack extends BaseLapack {
|
||||||
|
|
||||||
// synchronized on the solver
|
// synchronized on the solver
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getOldStream()));
|
status = cusolverDnSetStream(new cusolverDnContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
if (status == 0) {
|
if (status == 0) {
|
||||||
// transfer the INDArray into GPU memory
|
// transfer the INDArray into GPU memory
|
||||||
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
CublasPointer xAPointer = new CublasPointer(a, ctx);
|
||||||
CublasPointer xRPointer = new CublasPointer(R, ctx);
|
CublasPointer xRPointer = new CublasPointer(R, ctx);
|
||||||
|
|
||||||
// this output - indicates how much memory we'll need for the real operation
|
// this output - indicates how much memory we'll need for the real operation
|
||||||
DataBuffer worksizeBuffer = Nd4j.getDataBufferFactory().createInt(1);
|
val worksizeBuffer = (BaseCudaDataBuffer) Nd4j.getDataBufferFactory().createInt(1);
|
||||||
|
worksizeBuffer.lazyAllocateHostPointer();
|
||||||
|
|
||||||
status = cusolverDnDsyevd_bufferSize(
|
status = cusolverDnDsyevd_bufferSize(
|
||||||
solverDn, jobz, uplo, M,
|
solverDn, jobz, uplo, M,
|
||||||
(DoublePointer) xAPointer.getDevicePointer(), M,
|
(DoublePointer) xAPointer.getDevicePointer(), M,
|
||||||
|
|
|
@ -17,12 +17,11 @@
|
||||||
package org.nd4j.linalg.jcublas.blas;
|
package org.nd4j.linalg.jcublas.blas;
|
||||||
|
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.DoublePointer;
|
import org.bytedeco.javacpp.*;
|
||||||
import org.bytedeco.javacpp.FloatPointer;
|
|
||||||
import org.bytedeco.javacpp.IntPointer;
|
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.jita.allocator.Allocator;
|
import org.nd4j.jita.allocator.Allocator;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
|
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
||||||
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
|
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
|
||||||
import org.nd4j.linalg.api.blas.impl.BaseLevel1;
|
import org.nd4j.linalg.api.blas.impl.BaseLevel1;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
@ -37,6 +36,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.jcublas.CublasPointer;
|
import org.nd4j.linalg.jcublas.CublasPointer;
|
||||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
import org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner;
|
import org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner;
|
||||||
|
import org.nd4j.nativeblas.LongPointerWrapper;
|
||||||
import org.nd4j.nativeblas.NativeOps;
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import org.nd4j.nativeblas.Nd4jBlas;
|
import org.nd4j.nativeblas.Nd4jBlas;
|
||||||
|
@ -82,30 +82,9 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
Nd4j.getExecutioner().exec(dot);
|
Nd4j.getExecutioner().exec(dot);
|
||||||
|
|
||||||
ret = dot.getFinalResult().floatValue();
|
ret = dot.getFinalResult().floatValue();
|
||||||
/*
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
|
||||||
synchronized (handle) {
|
|
||||||
long result = cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
|
||||||
if (result != 0)
|
|
||||||
throw new IllegalStateException("cublasSetStream failed");
|
|
||||||
|
|
||||||
FloatPointer resultPointer = new FloatPointer(0.0f);
|
|
||||||
cuBlasSdot_v2(new cublasContext(handle),
|
|
||||||
N,
|
|
||||||
xCPointer.getDevicePointer(),
|
|
||||||
incX,
|
|
||||||
yCPointer.getDevicePointer(),
|
|
||||||
incY, resultPointer);
|
|
||||||
ret = resultPointer.get();
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
// allocator.registerAction(ctx, null, X, Y);
|
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected float sdot(long N, INDArray X, int incX, INDArray Y, int incY) {
|
protected float sdot(long N, INDArray X, int incX, INDArray Y, int incY) {
|
||||||
Preconditions.checkArgument(X.dataType() == DataType.FLOAT, "Float dtype expected");
|
Preconditions.checkArgument(X.dataType() == DataType.FLOAT, "Float dtype expected");
|
||||||
|
@ -114,22 +93,27 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
Nd4j.getExecutioner().push();
|
Nd4j.getExecutioner().push();
|
||||||
|
|
||||||
CudaContext ctx = allocator.getFlowController().prepareAction(null, X, Y);
|
val ctx = allocator.getFlowController().prepareAction(null, X, Y);
|
||||||
|
|
||||||
float ret = 1f;
|
float ret = 1f;
|
||||||
|
|
||||||
CublasPointer xCPointer = new CublasPointer(X, ctx);
|
val xCPointer = new CublasPointer(X, ctx);
|
||||||
CublasPointer yCPointer = new CublasPointer(Y, ctx);
|
val yCPointer = new CublasPointer(Y, ctx);
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
val handle = ctx.getHandle();
|
||||||
|
|
||||||
|
val cctx = new cublasContext(handle);
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
long result = cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
long result = cublasSetStream_v2(cctx, new CUstream_st(ctx.getCublasStream()));
|
||||||
if (result != 0)
|
if (result != 0)
|
||||||
throw new IllegalStateException("cublasSetStream failed");
|
throw new IllegalStateException("cublasSetStream failed");
|
||||||
|
|
||||||
FloatPointer resultPointer = new FloatPointer(0.0f);
|
val resultPointer = new FloatPointer(0.0f);
|
||||||
result = cublasSdot_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
|
result = cublasSdot_v2(cctx, (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX, (FloatPointer) yCPointer.getDevicePointer(), incY, resultPointer);
|
||||||
(FloatPointer) yCPointer.getDevicePointer(), incY, resultPointer);
|
|
||||||
|
if (result != 0)
|
||||||
|
throw new IllegalStateException("cublasSdot_v2 failed. Error code: " + result);
|
||||||
|
|
||||||
ret = resultPointer.get();
|
ret = resultPointer.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -155,17 +139,18 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
Nd4j.getExecutioner().push();
|
Nd4j.getExecutioner().push();
|
||||||
|
|
||||||
double ret;
|
double ret;
|
||||||
CudaContext ctx = allocator.getFlowController().prepareAction(null, X, Y);
|
val ctx = allocator.getFlowController().prepareAction(null, X, Y);
|
||||||
|
|
||||||
CublasPointer xCPointer = new CublasPointer(X, ctx);
|
val xCPointer = new CublasPointer(X, ctx);
|
||||||
CublasPointer yCPointer = new CublasPointer(Y, ctx);
|
val yCPointer = new CublasPointer(Y, ctx);
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
val handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
val cctx = new cublasContext(handle);
|
||||||
|
cublasSetStream_v2(cctx, new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
DoublePointer resultPointer = new DoublePointer(0.0);
|
val resultPointer = new DoublePointer(0.0);
|
||||||
cublasDdot_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
|
cublasDdot_v2(cctx, (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
|
||||||
(DoublePointer) yCPointer.getDevicePointer(), incY, resultPointer);
|
(DoublePointer) yCPointer.getDevicePointer(), incY, resultPointer);
|
||||||
ret = resultPointer.get();
|
ret = resultPointer.get();
|
||||||
}
|
}
|
||||||
|
@ -194,7 +179,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
FloatPointer resultPointer = new FloatPointer(0.0f);
|
FloatPointer resultPointer = new FloatPointer(0.0f);
|
||||||
cublasSnrm2_v2(new cublasContext(handle), (int) N, (FloatPointer) cAPointer.getDevicePointer(), incX,
|
cublasSnrm2_v2(new cublasContext(handle), (int) N, (FloatPointer) cAPointer.getDevicePointer(), incX,
|
||||||
|
@ -226,28 +211,6 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
float ret = asum.getFinalResult().floatValue();
|
float ret = asum.getFinalResult().floatValue();
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
/*
|
|
||||||
CudaContext ctx = allocator.getFlowController().prepareAction(null, X);
|
|
||||||
float ret;
|
|
||||||
|
|
||||||
CublasPointer xCPointer = new CublasPointer(X, ctx);
|
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
|
||||||
synchronized (handle) {
|
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
|
||||||
|
|
||||||
FloatPointer resultPointer = new FloatPointer(0.0f);
|
|
||||||
cublasSasum_v2(new cublasContext(handle),
|
|
||||||
N,
|
|
||||||
(FloatPointer) xCPointer.getDevicePointer(),
|
|
||||||
incX, resultPointer);
|
|
||||||
ret = resultPointer.get();
|
|
||||||
}
|
|
||||||
|
|
||||||
allocator.registerAction(ctx, null, X);
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -274,7 +237,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
DoublePointer resultPointer = new DoublePointer(0.0f);
|
DoublePointer resultPointer = new DoublePointer(0.0f);
|
||||||
cublasDnrm2_v2(new cublasContext(handle), (int) N, (DoublePointer) cAPointer.getDevicePointer(), incX,
|
cublasDnrm2_v2(new cublasContext(handle), (int) N, (DoublePointer) cAPointer.getDevicePointer(), incX,
|
||||||
|
@ -295,26 +258,6 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
double ret = asum.getFinalResult().doubleValue();
|
double ret = asum.getFinalResult().doubleValue();
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
/*CudaContext ctx = allocator.getFlowController().prepareAction(null, X);
|
|
||||||
double ret;
|
|
||||||
|
|
||||||
CublasPointer xCPointer = new CublasPointer(X, ctx);
|
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
|
||||||
synchronized (handle) {
|
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
|
||||||
|
|
||||||
DoublePointer resultPointer = new DoublePointer(0.0);
|
|
||||||
cublasDasum_v2(new cublasContext(handle),
|
|
||||||
N,
|
|
||||||
xCPointer.getDevicePointer(),
|
|
||||||
incX, resultPointer);
|
|
||||||
ret = resultPointer.get();
|
|
||||||
}
|
|
||||||
allocator.registerAction(ctx, null, X);
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -335,7 +278,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
IntPointer resultPointer = new IntPointer(new int[] {0});
|
IntPointer resultPointer = new IntPointer(new int[] {0});
|
||||||
cublasIsamax_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
|
cublasIsamax_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
|
||||||
|
@ -365,7 +308,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
IntPointer resultPointer = new IntPointer(new int[] {0});
|
IntPointer resultPointer = new IntPointer(new int[] {0});
|
||||||
cublasIdamax_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
|
cublasIdamax_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
|
||||||
|
@ -396,7 +339,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasSswap_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
|
cublasSswap_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
|
||||||
(FloatPointer) yCPointer.getDevicePointer(), incY);
|
(FloatPointer) yCPointer.getDevicePointer(), incY);
|
||||||
|
@ -420,7 +363,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasScopy_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
|
cublasScopy_v2(new cublasContext(handle), (int) N, (FloatPointer) xCPointer.getDevicePointer(), incX,
|
||||||
(FloatPointer) yCPointer.getDevicePointer(), incY);
|
(FloatPointer) yCPointer.getDevicePointer(), incY);
|
||||||
|
@ -443,28 +386,6 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
Nd4j.getExecutioner().exec(new Axpy(X, Y, Y, alpha));
|
Nd4j.getExecutioner().exec(new Axpy(X, Y, Y, alpha));
|
||||||
|
|
||||||
OpExecutionerUtil.checkForAny(Y);
|
OpExecutionerUtil.checkForAny(Y);
|
||||||
/*
|
|
||||||
CublasPointer xAPointer = new CublasPointer(X, ctx);
|
|
||||||
CublasPointer xBPointer = new CublasPointer(Y, ctx);
|
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
synchronized (handle) {
|
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
|
||||||
|
|
||||||
PointerPointer p = new cublasContext(handle);
|
|
||||||
cublasSaxpy_v2(p,
|
|
||||||
N,
|
|
||||||
alpha,
|
|
||||||
xAPointer.getDevicePointer(),
|
|
||||||
incX,
|
|
||||||
xBPointer.getDevicePointer(),
|
|
||||||
incY);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
// allocator.registerAction(ctx, Y, X);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -479,21 +400,6 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
((CudaExecutioner) Nd4j.getExecutioner()).exec(new Axpy(X, Y, Y, alpha));
|
((CudaExecutioner) Nd4j.getExecutioner()).exec(new Axpy(X, Y, Y, alpha));
|
||||||
|
|
||||||
OpExecutionerUtil.checkForAny(Y);
|
OpExecutionerUtil.checkForAny(Y);
|
||||||
|
|
||||||
/* synchronized (handle) {
|
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
|
||||||
|
|
||||||
PointerPointer p = new cublasContext(handle);
|
|
||||||
cublasSaxpy_v2(p,
|
|
||||||
N,
|
|
||||||
alpha,
|
|
||||||
xAPointer.getDevicePointer(),
|
|
||||||
incX,
|
|
||||||
xBPointer.getDevicePointer(),
|
|
||||||
incY);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
// allocator.registerAction(ctx, Y, X);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -520,7 +426,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDswap_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
|
cublasDswap_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
|
||||||
(DoublePointer) yCPointer.getDevicePointer(), incY);
|
(DoublePointer) yCPointer.getDevicePointer(), incY);
|
||||||
|
@ -542,7 +448,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDcopy_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
|
cublasDcopy_v2(new cublasContext(handle), (int) N, (DoublePointer) xCPointer.getDevicePointer(), incX,
|
||||||
(DoublePointer) yCPointer.getDevicePointer(), incY);
|
(DoublePointer) yCPointer.getDevicePointer(), incY);
|
||||||
|
@ -568,22 +474,6 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
Nd4j.getExecutioner().exec(new Axpy(X, Y, Y, alpha));
|
Nd4j.getExecutioner().exec(new Axpy(X, Y, Y, alpha));
|
||||||
|
|
||||||
OpExecutionerUtil.checkForAny(Y);
|
OpExecutionerUtil.checkForAny(Y);
|
||||||
|
|
||||||
/*
|
|
||||||
CublasPointer xAPointer = new CublasPointer(X, ctx);
|
|
||||||
CublasPointer xBPointer = new CublasPointer(Y, ctx);
|
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
|
||||||
synchronized (handle) {
|
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
|
||||||
|
|
||||||
cublasDaxpy_v2(new cublasContext(handle),
|
|
||||||
N, alpha, xAPointer.getDevicePointer(),
|
|
||||||
incX, xBPointer.getDevicePointer(),
|
|
||||||
incY);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
// allocator.registerAction(ctx, Y, X);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -652,7 +542,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasSscal_v2(new cublasContext(handle),(int) N, new FloatPointer(alpha),
|
cublasSscal_v2(new cublasContext(handle),(int) N, new FloatPointer(alpha),
|
||||||
(FloatPointer) xCPointer.getDevicePointer(), incX);
|
(FloatPointer) xCPointer.getDevicePointer(), incX);
|
||||||
|
@ -675,7 +565,7 @@ public class JcublasLevel1 extends BaseLevel1 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDscal_v2(new cublasContext(handle), (int) N, new DoublePointer(alpha),
|
cublasDscal_v2(new cublasContext(handle), (int) N, new DoublePointer(alpha),
|
||||||
(DoublePointer) xCPointer.getDevicePointer(), incX);
|
(DoublePointer) xCPointer.getDevicePointer(), incX);
|
||||||
|
|
|
@ -64,7 +64,7 @@ public class JcublasLevel2 extends BaseLevel2 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasSgemv_v2(new cublasContext(handle), convertTranspose(TransA), M, N, new FloatPointer(alpha),
|
cublasSgemv_v2(new cublasContext(handle), convertTranspose(TransA), M, N, new FloatPointer(alpha),
|
||||||
(FloatPointer) cAPointer.getDevicePointer(), lda,
|
(FloatPointer) cAPointer.getDevicePointer(), lda,
|
||||||
|
@ -136,7 +136,7 @@ public class JcublasLevel2 extends BaseLevel2 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDgemv_v2(new cublasContext(handle), convertTranspose(TransA), M, N, new DoublePointer(alpha),
|
cublasDgemv_v2(new cublasContext(handle), convertTranspose(TransA), M, N, new DoublePointer(alpha),
|
||||||
(DoublePointer) cAPointer.getDevicePointer(), lda,
|
(DoublePointer) cAPointer.getDevicePointer(), lda,
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.linalg.jcublas.blas;
|
package org.nd4j.linalg.jcublas.blas;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.val;
|
||||||
import org.bytedeco.javacpp.DoublePointer;
|
import org.bytedeco.javacpp.DoublePointer;
|
||||||
import org.bytedeco.javacpp.FloatPointer;
|
import org.bytedeco.javacpp.FloatPointer;
|
||||||
import org.bytedeco.javacpp.ShortPointer;
|
import org.bytedeco.javacpp.ShortPointer;
|
||||||
|
@ -73,7 +74,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
|
int arch = CudaEnvironment.getInstance().getCurrentDeviceArchitecture();
|
||||||
|
|
||||||
|
@ -111,15 +112,15 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
Nd4j.getExecutioner().push();
|
Nd4j.getExecutioner().push();
|
||||||
|
|
||||||
CudaContext ctx = allocator.getFlowController().prepareAction(C, A, B);
|
val ctx = allocator.getFlowController().prepareAction(C, A, B);
|
||||||
|
|
||||||
CublasPointer cAPointer = new CublasPointer(A, ctx);
|
val cAPointer = new CublasPointer(A, ctx);
|
||||||
CublasPointer cBPointer = new CublasPointer(B, ctx);
|
val cBPointer = new CublasPointer(B, ctx);
|
||||||
CublasPointer cCPointer = new CublasPointer(C, ctx);
|
val cCPointer = new CublasPointer(C, ctx);
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
val handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
|
cublasSgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
|
||||||
new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
|
new FloatPointer(alpha), (FloatPointer) cAPointer.getDevicePointer(), lda,
|
||||||
|
@ -145,7 +146,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasSsymm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), M, N,
|
cublasSsymm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), M, N,
|
||||||
new FloatPointer(alpha), (FloatPointer) aPointer.getDevicePointer(), lda,
|
new FloatPointer(alpha), (FloatPointer) aPointer.getDevicePointer(), lda,
|
||||||
|
@ -170,7 +171,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasSsyrk_v2(new cublasContext(handle), convertUplo(Uplo), convertTranspose(Trans), N, K,
|
cublasSsyrk_v2(new cublasContext(handle), convertUplo(Uplo), convertTranspose(Trans), N, K,
|
||||||
new FloatPointer(alpha), (FloatPointer) aPointer.getDevicePointer(), lda,
|
new FloatPointer(alpha), (FloatPointer) aPointer.getDevicePointer(), lda,
|
||||||
|
@ -207,7 +208,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasStrsm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo),
|
cublasStrsm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo),
|
||||||
convertTranspose(TransA), convertDiag(Diag), M, N, new FloatPointer(alpha),
|
convertTranspose(TransA), convertDiag(Diag), M, N, new FloatPointer(alpha),
|
||||||
|
@ -227,17 +228,17 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
Nd4j.getExecutioner().push();
|
Nd4j.getExecutioner().push();
|
||||||
|
|
||||||
CudaContext ctx = allocator.getFlowController().prepareAction(C, A, B);
|
val ctx = allocator.getFlowController().prepareAction(C, A, B);
|
||||||
|
|
||||||
DataTypeValidation.assertDouble(A, B, C);
|
DataTypeValidation.assertDouble(A, B, C);
|
||||||
|
|
||||||
CublasPointer cAPointer = new CublasPointer(A, ctx);
|
val cAPointer = new CublasPointer(A, ctx);
|
||||||
CublasPointer cBPointer = new CublasPointer(B, ctx);
|
val cBPointer = new CublasPointer(B, ctx);
|
||||||
CublasPointer cCPointer = new CublasPointer(C, ctx);
|
val cCPointer = new CublasPointer(C, ctx);
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
val handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
|
cublasDgemm_v2(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K,
|
||||||
new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
|
new DoublePointer(alpha), (DoublePointer) cAPointer.getDevicePointer(), lda,
|
||||||
|
@ -262,7 +263,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDsymm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), M, N,
|
cublasDsymm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo), M, N,
|
||||||
new DoublePointer(alpha), (DoublePointer) aPointer.getDevicePointer(), lda,
|
new DoublePointer(alpha), (DoublePointer) aPointer.getDevicePointer(), lda,
|
||||||
|
@ -287,7 +288,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDsyrk_v2(new cublasContext(handle), convertUplo(Uplo), Trans, N, K, new DoublePointer(alpha),
|
cublasDsyrk_v2(new cublasContext(handle), convertUplo(Uplo), Trans, N, K, new DoublePointer(alpha),
|
||||||
(DoublePointer) aPointer.getDevicePointer(), lda, new DoublePointer(beta),
|
(DoublePointer) aPointer.getDevicePointer(), lda, new DoublePointer(beta),
|
||||||
|
@ -312,7 +313,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDsyr2k_v2(new cublasContext(handle), convertUplo(Uplo), Trans, N, K, new DoublePointer(alpha),
|
cublasDsyr2k_v2(new cublasContext(handle), convertUplo(Uplo), Trans, N, K, new DoublePointer(alpha),
|
||||||
(DoublePointer) aPointer.getDevicePointer(), lda,
|
(DoublePointer) aPointer.getDevicePointer(), lda,
|
||||||
|
@ -337,7 +338,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDtrmm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo),
|
cublasDtrmm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo),
|
||||||
convertTranspose(TransA), convertDiag(Diag), M, N, new DoublePointer(alpha),
|
convertTranspose(TransA), convertDiag(Diag), M, N, new DoublePointer(alpha),
|
||||||
|
@ -363,7 +364,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
||||||
|
|
||||||
cublasHandle_t handle = ctx.getHandle();
|
cublasHandle_t handle = ctx.getHandle();
|
||||||
synchronized (handle) {
|
synchronized (handle) {
|
||||||
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
|
cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getCublasStream()));
|
||||||
|
|
||||||
cublasDtrsm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo),
|
cublasDtrsm_v2(new cublasContext(handle), convertSideMode(Side), convertUplo(Uplo),
|
||||||
convertTranspose(TransA), convertDiag(Diag), M, N, new DoublePointer(alpha),
|
convertTranspose(TransA), convertDiag(Diag), M, N, new DoublePointer(alpha),
|
||||||
|
|
|
@ -241,7 +241,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
initPointers(length, Nd4j.sizeOfDataType(dtype), initialize);
|
initPointers(length, Nd4j.sizeOfDataType(dtype), initialize);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void lazyAllocateHostPointer() {
|
public void lazyAllocateHostPointer() {
|
||||||
if (allocationPoint.getPointers().getHostPointer() == null)
|
if (allocationPoint.getPointers().getHostPointer() == null)
|
||||||
initHostPointerAndIndexer();
|
initHostPointerAndIndexer();
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,6 +42,10 @@ public class CudaBfloat16DataBuffer extends BaseCudaDataBuffer {
|
||||||
super(pointer, indexer, length);
|
super(pointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CudaBfloat16DataBuffer(Pointer pointer, Pointer specialPointer, Indexer indexer, long length){
|
||||||
|
super(pointer, specialPointer, indexer, length);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base constructor
|
* Base constructor
|
||||||
*
|
*
|
||||||
|
|
|
@ -42,6 +42,10 @@ public class CudaUInt16DataBuffer extends BaseCudaDataBuffer {
|
||||||
super(pointer, indexer, length);
|
super(pointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CudaUInt16DataBuffer(Pointer pointer, Pointer specialPointer, Indexer indexer, long length){
|
||||||
|
super(pointer, specialPointer, indexer, length);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base constructor
|
* Base constructor
|
||||||
*
|
*
|
||||||
|
|
|
@ -42,6 +42,10 @@ public class CudaUInt32DataBuffer extends BaseCudaDataBuffer {
|
||||||
super(pointer, indexer, length);
|
super(pointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CudaUInt32DataBuffer(Pointer pointer, Pointer specialPointer, Indexer indexer, long length){
|
||||||
|
super(pointer, specialPointer, indexer, length);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base constructor
|
* Base constructor
|
||||||
*
|
*
|
||||||
|
|
|
@ -100,6 +100,10 @@ public class CudaUInt64DataBuffer extends BaseCudaDataBuffer {
|
||||||
super(data, copy, offset, workspace);
|
super(data, copy, offset, workspace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CudaUInt64DataBuffer(Pointer pointer, Pointer specialPointer, Indexer indexer, long length){
|
||||||
|
super(pointer, specialPointer, indexer, length);
|
||||||
|
}
|
||||||
|
|
||||||
public CudaUInt64DataBuffer(double[] data) {
|
public CudaUInt64DataBuffer(double[] data) {
|
||||||
super(data);
|
super(data);
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,10 +72,18 @@ public class CudaDataBufferFactory implements DataBufferFactory {
|
||||||
return new CudaFloatDataBuffer(underlyingBuffer, length, offset);
|
return new CudaFloatDataBuffer(underlyingBuffer, length, offset);
|
||||||
case HALF:
|
case HALF:
|
||||||
return new CudaHalfDataBuffer(underlyingBuffer, length, offset);
|
return new CudaHalfDataBuffer(underlyingBuffer, length, offset);
|
||||||
|
case BFLOAT16:
|
||||||
|
return new CudaBfloat16DataBuffer(underlyingBuffer, length, offset);
|
||||||
|
case UINT64:
|
||||||
|
return new CudaUInt64DataBuffer(underlyingBuffer, length, offset);
|
||||||
case LONG:
|
case LONG:
|
||||||
return new CudaLongDataBuffer(underlyingBuffer, length, offset);
|
return new CudaLongDataBuffer(underlyingBuffer, length, offset);
|
||||||
|
case UINT32:
|
||||||
|
return new CudaUInt32DataBuffer(underlyingBuffer, length, offset);
|
||||||
case INT:
|
case INT:
|
||||||
return new CudaIntDataBuffer(underlyingBuffer, length, offset);
|
return new CudaIntDataBuffer(underlyingBuffer, length, offset);
|
||||||
|
case UINT16:
|
||||||
|
return new CudaUInt16DataBuffer(underlyingBuffer, length, offset);
|
||||||
case SHORT:
|
case SHORT:
|
||||||
return new CudaShortDataBuffer(underlyingBuffer, length, offset);
|
return new CudaShortDataBuffer(underlyingBuffer, length, offset);
|
||||||
case UBYTE:
|
case UBYTE:
|
||||||
|
@ -684,16 +692,32 @@ public class CudaDataBufferFactory implements DataBufferFactory {
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer create(Pointer pointer, DataType type, long length, Indexer indexer) {
|
public DataBuffer create(Pointer pointer, DataType type, long length, Indexer indexer) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
|
case UINT64:
|
||||||
|
return new CudaUInt64DataBuffer(pointer, indexer, length);
|
||||||
case LONG:
|
case LONG:
|
||||||
return new CudaLongDataBuffer(pointer, indexer, length);
|
return new CudaLongDataBuffer(pointer, indexer, length);
|
||||||
|
case UINT32:
|
||||||
|
return new CudaUInt32DataBuffer(pointer, indexer, length);
|
||||||
case INT:
|
case INT:
|
||||||
return new CudaIntDataBuffer(pointer, indexer, length);
|
return new CudaIntDataBuffer(pointer, indexer, length);
|
||||||
|
case UINT16:
|
||||||
|
return new CudaUInt16DataBuffer(pointer, indexer, length);
|
||||||
|
case SHORT:
|
||||||
|
return new CudaShortDataBuffer(pointer, indexer, length);
|
||||||
|
case UBYTE:
|
||||||
|
return new CudaUByteDataBuffer(pointer, indexer, length);
|
||||||
|
case BYTE:
|
||||||
|
return new CudaByteDataBuffer(pointer, indexer, length);
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
return new CudaDoubleDataBuffer(pointer, indexer, length);
|
return new CudaDoubleDataBuffer(pointer, indexer, length);
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
return new CudaFloatDataBuffer(pointer, indexer, length);
|
return new CudaFloatDataBuffer(pointer, indexer, length);
|
||||||
case HALF:
|
case HALF:
|
||||||
return new CudaHalfDataBuffer(pointer, indexer, length);
|
return new CudaHalfDataBuffer(pointer, indexer, length);
|
||||||
|
case BFLOAT16:
|
||||||
|
return new CudaBfloat16DataBuffer(pointer, indexer, length);
|
||||||
|
case BOOL:
|
||||||
|
return new CudaBoolDataBuffer(pointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new IllegalArgumentException("Illegal dtype " + type);
|
throw new IllegalArgumentException("Illegal dtype " + type);
|
||||||
|
@ -702,16 +726,32 @@ public class CudaDataBufferFactory implements DataBufferFactory {
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer create(Pointer pointer, Pointer specialPointer, DataType type, long length, Indexer indexer) {
|
public DataBuffer create(Pointer pointer, Pointer specialPointer, DataType type, long length, Indexer indexer) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
|
case UINT64:
|
||||||
|
return new CudaUInt64DataBuffer(pointer, specialPointer, indexer, length);
|
||||||
case LONG:
|
case LONG:
|
||||||
return new CudaLongDataBuffer(pointer, specialPointer, indexer, length);
|
return new CudaLongDataBuffer(pointer, specialPointer, indexer, length);
|
||||||
|
case UINT32:
|
||||||
|
return new CudaUInt32DataBuffer(pointer, specialPointer, indexer, length);
|
||||||
case INT:
|
case INT:
|
||||||
return new CudaIntDataBuffer(pointer, specialPointer, indexer, length);
|
return new CudaIntDataBuffer(pointer, specialPointer, indexer, length);
|
||||||
|
case UINT16:
|
||||||
|
return new CudaUInt16DataBuffer(pointer, specialPointer, indexer, length);
|
||||||
|
case SHORT:
|
||||||
|
return new CudaShortDataBuffer(pointer, specialPointer, indexer, length);
|
||||||
|
case UBYTE:
|
||||||
|
return new CudaUByteDataBuffer(pointer, specialPointer, indexer, length);
|
||||||
|
case BYTE:
|
||||||
|
return new CudaByteDataBuffer(pointer, specialPointer, indexer, length);
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
return new CudaDoubleDataBuffer(pointer, specialPointer, indexer, length);
|
return new CudaDoubleDataBuffer(pointer, specialPointer, indexer, length);
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
return new CudaFloatDataBuffer(pointer, specialPointer, indexer, length);
|
return new CudaFloatDataBuffer(pointer, specialPointer, indexer, length);
|
||||||
case HALF:
|
case HALF:
|
||||||
return new CudaHalfDataBuffer(pointer, specialPointer, indexer, length);
|
return new CudaHalfDataBuffer(pointer, specialPointer, indexer, length);
|
||||||
|
case BFLOAT16:
|
||||||
|
return new CudaBfloat16DataBuffer(pointer, specialPointer, indexer, length);
|
||||||
|
case BOOL:
|
||||||
|
return new CudaBoolDataBuffer(pointer, specialPointer, indexer, length);
|
||||||
}
|
}
|
||||||
|
|
||||||
throw new IllegalArgumentException("Illegal dtype " + type);
|
throw new IllegalArgumentException("Illegal dtype " + type);
|
||||||
|
|
|
@ -17,9 +17,13 @@
|
||||||
package org.nd4j.linalg.jcublas.context;
|
package org.nd4j.linalg.jcublas.context;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.val;
|
||||||
|
import org.bytedeco.javacpp.LongPointer;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
import org.bytedeco.javacpp.PointerPointer;
|
||||||
import org.nd4j.jita.allocator.garbage.GarbageResourceReference;
|
import org.nd4j.jita.allocator.garbage.GarbageResourceReference;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
|
import org.nd4j.jita.allocator.pointers.CudaPointer;
|
||||||
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
|
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
|
||||||
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
|
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
|
||||||
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
|
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
|
||||||
|
@ -46,7 +50,6 @@ public class CudaContext {
|
||||||
//private CUevent cUevent;
|
//private CUevent cUevent;
|
||||||
private cudaStream_t oldStream;
|
private cudaStream_t oldStream;
|
||||||
|
|
||||||
private cudaStream_t cublasStream;
|
|
||||||
private cudaStream_t solverStream;
|
private cudaStream_t solverStream;
|
||||||
|
|
||||||
private cudaStream_t specialStream;
|
private cudaStream_t specialStream;
|
||||||
|
@ -130,17 +133,11 @@ public class CudaContext {
|
||||||
// ContextHolder.getInstance().setContext();
|
// ContextHolder.getInstance().setContext();
|
||||||
if (nativeOps.streamSynchronize(oldStream) == 0)
|
if (nativeOps.streamSynchronize(oldStream) == 0)
|
||||||
throw new ND4JIllegalStateException("CUDA stream synchronization failed");
|
throw new ND4JIllegalStateException("CUDA stream synchronization failed");
|
||||||
|
|
||||||
if (syncCuBlas)
|
|
||||||
syncCublasStream();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void syncCublasStream() {
|
public Pointer getCublasStream() {
|
||||||
if (cublasStream != null) {
|
val lptr = new PointerPointer(this.getOldStream());
|
||||||
if (nativeOps.streamSynchronize(cublasStream) == 0)
|
return lptr.get(0);
|
||||||
throw new ND4JIllegalStateException("CUDA stream synchronization failed");
|
|
||||||
} else
|
|
||||||
throw new IllegalStateException("cuBLAS stream isnt set");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1984,13 +1984,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickDeviceWrite();
|
AtomicAllocator.getInstance().getAllocationPoint(encodedBuffer).tickDeviceWrite();
|
||||||
AtomicAllocator.getInstance().getAllocationPoint(buffer).tickDeviceWrite();
|
AtomicAllocator.getInstance().getAllocationPoint(buffer).tickDeviceWrite();
|
||||||
|
|
||||||
|
|
||||||
// just to ensure it's not purged
|
|
||||||
extras.address();
|
|
||||||
tempX.address();
|
|
||||||
buffers.getClass();
|
|
||||||
|
|
||||||
|
|
||||||
return Nd4j.createArrayFromShapeBuffer(encodedBuffer, input.shapeInfoDataBuffer());
|
return Nd4j.createArrayFromShapeBuffer(encodedBuffer, input.shapeInfoDataBuffer());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2171,14 +2164,17 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
}
|
}
|
||||||
|
|
||||||
val inputBuffers = new PointerPointer<>(op.inputArguments().length);
|
val inputBuffers = new PointerPointer<>(op.inputArguments().length * 2);
|
||||||
val inputShapes = new PointerPointer<>(op.inputArguments().length);
|
val inputShapes = new PointerPointer<>(op.inputArguments().length);
|
||||||
|
|
||||||
int cnt= 0;
|
int cnt= 0;
|
||||||
for (val in: op.inputArguments()) {
|
for (val in: op.inputArguments()) {
|
||||||
// NOT A TYPO: shape functions work on host side only
|
// NOT A TYPO: shape functions work on host side only
|
||||||
if (!in.isEmpty())
|
if (!in.isEmpty()) {
|
||||||
inputBuffers.put(cnt, in.data().addressPointer());
|
inputBuffers.put(cnt, in.data().addressPointer());
|
||||||
|
inputBuffers.put(cnt + op.inputArguments().length, AtomicAllocator.getInstance().getPointer(in.data()));
|
||||||
|
}
|
||||||
|
|
||||||
inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer());
|
inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4394,9 +4394,9 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
|
||||||
* - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
|
* - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
|
||||||
* - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both)
|
* - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both)
|
||||||
* direction - in what direction to fill matrix. There are 3 possible directions:
|
* direction - in what direction to fill matrix. There are 3 possible directions:
|
||||||
* 'u' - fill up, mathematically this corresponds to lower triangular matrix, parameter "lower" is not taken into account
|
* 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected
|
||||||
* 'l' - fill down, mathematically this corresponds to upper triangular matrix, parameter "upper" is not taken into account
|
* 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected
|
||||||
* 'b' - fill in both directions, both parameters "lower" and "upper" are taken into account
|
* 'b' - fill in both directions, both "lower" and "upper" are taken into account
|
||||||
* rest of target elements are equal to this array elements
|
* rest of target elements are equal to this array elements
|
||||||
* target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2)
|
* target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2)
|
||||||
*/
|
*/
|
||||||
|
@ -7857,9 +7857,18 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
* @param indices the indices to iterate over
|
* @param indices the indices to iterate over
|
||||||
* @return the double at the specified index
|
* @return the double at the specified index
|
||||||
*/
|
*/
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices,int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices, int rank);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices,int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices, int rank);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices,int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices, int rank);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices);
|
||||||
|
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
|
||||||
|
@ -8027,12 +8036,13 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
||||||
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
|
||||||
|
|
||||||
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
||||||
// rank is equal to size of shape
|
// rank is equal to size of shape
|
||||||
|
|
|
@ -4394,9 +4394,9 @@ public static class NativeOps extends org.nd4j.nativeblas.NativeOps {
|
||||||
* - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
|
* - down from main diagonal starting at subdiagonal number "lower" if direction = 'd' (down) or 'b' (both)
|
||||||
* - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both)
|
* - up from main diagonal starting at superdiagonal number "upper"if direction = 'u' (up) or 'b' (both)
|
||||||
* direction - in what direction to fill matrix. There are 3 possible directions:
|
* direction - in what direction to fill matrix. There are 3 possible directions:
|
||||||
* 'u' - fill up, mathematically this corresponds to lower triangular matrix, parameter "lower" is not taken into account
|
* 'u' - fill up, mathematically this corresponds to lower triangular matrix, subdiagonal "lower" unaffected
|
||||||
* 'l' - fill down, mathematically this corresponds to upper triangular matrix, parameter "upper" is not taken into account
|
* 'l' - fill down, mathematically this corresponds to upper triangular matrix, superdiagonal "upper" remains unaffected
|
||||||
* 'b' - fill in both directions, both parameters "lower" and "upper" are taken into account
|
* 'b' - fill in both directions, both "lower" and "upper" are taken into account
|
||||||
* rest of target elements are equal to this array elements
|
* rest of target elements are equal to this array elements
|
||||||
* target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2)
|
* target and this array should have same shapes, except when this_rank = 1 (in that case should be target_rank = 2)
|
||||||
*/
|
*/
|
||||||
|
@ -7857,9 +7857,18 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
* @param indices the indices to iterate over
|
* @param indices the indices to iterate over
|
||||||
* @return the double at the specified index
|
* @return the double at the specified index
|
||||||
*/
|
*/
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices,int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongPointer shape, @Cast("const Nd4jLong*") LongPointer stride, @Cast("const Nd4jLong*") LongPointer indices, int rank);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices,int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") LongBuffer shape, @Cast("const Nd4jLong*") LongBuffer stride, @Cast("const Nd4jLong*") LongBuffer indices, int rank);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices,int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("Nd4jLong") long baseOffset, @Cast("const Nd4jLong*") long[] shape, @Cast("const Nd4jLong*") long[] stride, @Cast("const Nd4jLong*") long[] indices, int rank);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("const Nd4jLong*") LongPointer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("const Nd4jLong*") LongBuffer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices, @Cast("Nd4jLong") long baseOffset/*=0*/);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("const Nd4jLong*") long[] indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongPointer shapeInfo, @Cast("uint*") @StdVector IntPointer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") LongBuffer shapeInfo, @Cast("uint*") @StdVector IntBuffer indices);
|
||||||
|
@Namespace("shape") public static native @Cast("Nd4jLong") long getOffset(@Cast("const Nd4jLong*") long[] shapeInfo, @Cast("uint*") @StdVector int[] indices);
|
||||||
|
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong*") LongPointer createShapeInfo(@Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer stride, int rank);
|
||||||
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
|
@Namespace("shape") public static native @Cast("Nd4jLong*") LongBuffer createShapeInfo(@Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer stride, int rank);
|
||||||
|
@ -8027,12 +8036,13 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
// calculate offsets of max-array, these output offsets correspond to one minIdx index of min-array which is sub-array of max-array
|
||||||
// dimsToExclude - should be sorted in increasing order
|
// dimsToExclude - should be sorted in increasing order
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Const IntPointer dimsToExclude/*=nullptr*/);
|
// memBuff - auxiliary memory buffer (size = 2 * max_rank) for coordinates and increments storing, should be passed from outside
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff, @Const IntPointer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongPointer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongPointer maxShapeInfo, @Cast("const Nd4jLong*") LongPointer minShapeInfo, @Cast("Nd4jLong*") LongPointer memBuff);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff, @Const IntBuffer dimsToExclude/*=nullptr*/);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Const int[] dimsToExclude/*=nullptr*/);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") LongBuffer maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") LongBuffer maxShapeInfo, @Cast("const Nd4jLong*") LongBuffer minShapeInfo, @Cast("Nd4jLong*") LongBuffer memBuff);
|
||||||
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo);
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff, @Const int[] dimsToExclude/*=nullptr*/);
|
||||||
|
@Namespace("shape") public static native int outerArrayOffsets(@Cast("Nd4jLong*") long[] maxOffsets, @Cast("const Nd4jLong") long minIdx, @Cast("const Nd4jLong*") long[] maxShapeInfo, @Cast("const Nd4jLong*") long[] minShapeInfo, @Cast("Nd4jLong*") long[] memBuff);
|
||||||
|
|
||||||
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
// calculates offsets for entities (elements or sub-arrays), shape in context of sub-array means dimensions excluded from outer array
|
||||||
// rank is equal to size of shape
|
// rank is equal to size of shape
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.junit.rules.TemporaryFolder;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.Parameterized;
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.imports.TFGraphs.NodeReader;
|
import org.nd4j.imports.TFGraphs.NodeReader;
|
||||||
|
import org.nd4j.linalg.api.blas.Level1;
|
||||||
import org.nd4j.linalg.api.blas.params.GemmParams;
|
import org.nd4j.linalg.api.blas.params.GemmParams;
|
||||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
@ -118,6 +119,7 @@ import static org.junit.Assert.assertArrayEquals;
|
||||||
public class Nd4jTestsC extends BaseNd4jTest {
|
public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
DataType initialType;
|
DataType initialType;
|
||||||
|
Level1 l1;
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TemporaryFolder testDir = new TemporaryFolder();
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
@ -125,6 +127,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
public Nd4jTestsC(Nd4jBackend backend) {
|
public Nd4jTestsC(Nd4jBackend backend) {
|
||||||
super(backend);
|
super(backend);
|
||||||
this.initialType = Nd4j.dataType();
|
this.initialType = Nd4j.dataType();
|
||||||
|
l1 = Nd4j.getBlasWrapper().level1();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -431,7 +434,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMmulOp() {
|
public void testMmulOp() throws Exception {
|
||||||
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
INDArray arr = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
|
||||||
INDArray z = Nd4j.create(2, 2);
|
INDArray z = Nd4j.create(2, 2);
|
||||||
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
INDArray assertion = Nd4j.create(new double[][] {{14, 32}, {32, 77}});
|
||||||
|
@ -2797,15 +2800,21 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDot() {
|
public void testDot() throws Exception {
|
||||||
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray vec1 = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
|
INDArray vec2 = Nd4j.create(new float[] {1, 2, 3, 4});
|
||||||
|
|
||||||
|
assertEquals(10.f, vec1.sumNumber().floatValue(), 1e-5);
|
||||||
|
assertEquals(10.f, vec2.sumNumber().floatValue(), 1e-5);
|
||||||
|
|
||||||
assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1);
|
assertEquals(30, Nd4j.getBlasWrapper().dot(vec1, vec2), 1e-1);
|
||||||
|
|
||||||
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
INDArray matrix = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
|
||||||
INDArray row = matrix.getRow(1);
|
INDArray row = matrix.getRow(1);
|
||||||
assertEquals(25, Nd4j.getBlasWrapper().dot(row, row), 1e-1);
|
|
||||||
|
|
||||||
|
assertEquals(7.0f, row.sumNumber().floatValue(), 1e-5f);
|
||||||
|
|
||||||
|
assertEquals(25, Nd4j.getBlasWrapper().dot(row, row), 1e-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -2815,8 +2824,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
|
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
|
||||||
eye = Nd4j.eye(5);
|
eye = Nd4j.eye(5);
|
||||||
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
|
assertTrue(Arrays.equals(new long[] {5, 5}, eye.shape()));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -4224,6 +4231,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(10, matrix.rows());
|
assertEquals(10, matrix.rows());
|
||||||
assertEquals(6, matrix.columns());
|
assertEquals(6, matrix.columns());
|
||||||
|
|
||||||
|
log.info("Result: {}", matrix);
|
||||||
|
|
||||||
for (int x = 0; x < 10; x++) {
|
for (int x = 0; x < 10; x++) {
|
||||||
assertEquals((double) x, matrix.getRow(x).meanNumber().doubleValue(), 0.1);
|
assertEquals((double) x, matrix.getRow(x).meanNumber().doubleValue(), 0.1);
|
||||||
assertEquals(arrays.get(x), matrix.getRow(x).reshape(1, matrix.size(1)));
|
assertEquals(arrays.get(x), matrix.getRow(x).reshape(1, matrix.size(1)));
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
|
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
import java.nio.ShortBuffer;
|
import java.nio.ShortBuffer;
|
||||||
|
|
Loading…
Reference in New Issue