[WIP] More of CUDA (#95)
* initial commit Signed-off-by: raver119 <raver119@gmail.com>
* Implementation of hashcode cuda helper. Working edition.
* Fixed parallel test input arrangements.
* Fixed tests for hashcode op.
* Fixed shape calculation for image:crop_and_resize op and test.
* NativeOps tests. Initial test suite.
* Added tests for indexReduce methods.
* Added test on execBroadcast with NDArray as dimensions.
* Added test on execBroadcastBool with NDArray as dimensions.
* Added tests on execPairwiseTransform and execPairwiseTransformBool.
* Added tests for execReduce with scalar results.
* Added reduce tests for non-empty dims array.
* Added tests for reduce3.
* Added tests for execScalar.
* Added tests for execSummaryStats.
* Provided cpu/cuda code for batch_to_space and tested it. Signed-off-by: Yurii <yurii@skymind.io>
* Removed old test for batch_to_space (it had the wrong format and its numbers were not checked). Signed-off-by: Yurii <yurii@skymind.io>
* Fixed compilation errors with test.
* Added test for execTransformFloat.
* Added test for execTransformSame.
* Added test for execTransformBool.
* Added test for execTransformStrict.
* Added tests for execScalar/execScalarBool with TADs.
* Added test for flatten.
* Provided cpu/cuda code for the space_to_batch operation. Signed-off-by: Yurii <yurii@skymind.io>
* Added test for concat.
* Commented out unnecessary code in space_to_batch. Signed-off-by: Yurii <yurii@skymind.io>
* Added test for specialConcat.
* Added tests for memcpy/set routines.
* Fixed pullRow cuda test.
* Added pullRow test.
* Added average test.
* Corrected typo in NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op...). Signed-off-by: Yurii <yurii@skymind.io>
* Debugged and fixed cuda tests in the JavaInteropTests file. Signed-off-by: Yurii <yurii@skymind.io>
* Corrected some tests. Signed-off-by: Yurii <yurii@skymind.io>
* Added test for shuffle.
* Fixed ops declarations.
* Restored omp and added shuffle test.
* Added convertTypes test.
* Added tests for execRandom.
* Eliminated usage of RandomBuffer with NativeOps.
* Added sort tests.
* Added tests for execCustomOp.
* Further debugging and fixing of tests that terminated with a crash. Signed-off-by: Yurii <yurii@skymind.io>
* Added tests for calculateOutputShapes.
* Added Benchmarks test.
* Commented out benchmark tests.
* Changed assertion. Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for apply_sgd op. Added cpu helper for that op.
* Implemented cuda helper for apply_sgd op. Fixed tests for NativeOps.
* Added test for assign broadcastable.
* Added tests for assign_bp op.
* Added tests for axpy op.
* assign/execScalar/execTransformAny signature change; minor test fix. Signed-off-by: raver119 <raver119@gmail.com>
* Fixed axpy op.
* meh Signed-off-by: raver119 <raver119@gmail.com>
* Fixed tests for nativeOps::concat. Signed-off-by: Yurii <yurii@skymind.io>
* Sequential transform/scalar. Signed-off-by: raver119 <raver119@gmail.com>
* Allow nested parallelism. Signed-off-by: raver119 <raver119@gmail.com>
* assign_bp leak fix. Signed-off-by: raver119 <raver119@gmail.com>
* block setRNG fix. Signed-off-by: raver119 <raver119@gmail.com>
* Enable parallelism by default. Signed-off-by: raver119 <raver119@gmail.com>
* Enable nested parallelism by default. Signed-off-by: raver119 <raver119@gmail.com>
* Added cuda implementation for row_count helper.
* Added implementation for tsne gains op helper.
* Take into account situations where input arrays are empty in the reduce_ cuda code. Signed-off-by: Yurii <yurii@skymind.io>
* Implemented tsne/edge_forces op cuda-based helper. Parallelized cpu-based helper for edge_forces.
* Added kernel for tsne/symmetrized op helper.
* Implementation of tsne/symmetrized op cuda helper. Working edition.
* Removed stray printfs.
* Added test for broadcastgradientargs op.
* Host-only fallback for empty reduce float. Signed-off-by: raver119 <raver119@gmail.com>
* Some test fixes. Signed-off-by: Yurii <yurii@skymind.io>
* Corrected the rest of the reduce_ code. Signed-off-by: Yurii <yurii@skymind.io>
* Further corrections of the reduce_ code. Signed-off-by: Yurii <yurii@skymind.io>
* Added test for Cbow op. Also added cuda implementation for cbow helpers.
* Improved the code of the stack operation for the scalar case. Signed-off-by: Yurii <yurii@skymind.io>
* Provided cuda kernel for the gatherND operation. Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of cbow helpers with cuda kernels.
* Minor tests tweaks. Signed-off-by: raver119 <raver119@gmail.com>
* Minor tests tweaks. Signed-off-by: raver119 <raver119@gmail.com>
* Further corrections of the cuda code. Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of cbow op helper with cuda kernels. Working edition.
* Skip random testing for the cudablas case.
* lstmBlockCell context fix. Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for ELU and ELU_BP ops.
* Added tests for eq_scalar, gt_scalar, gte_scalar and lte_scalar ops.
* Added tests for neq_scalar.
* Added test for noop.
* Further work on clipbynorm_bp. Signed-off-by: Yurii <yurii@skymind.io>
* Got rid of the concat op call; use a direct concat helper call instead. Signed-off-by: Yurii <yurii@skymind.io>
* lstmBlockCell context fix. Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for lrelu and lrelu_bp.
* Added tests for selu and selu_bp.
* Fixed lrelu derivative helpers.
* Some corrections in lstm. Signed-off-by: Yurii <yurii@skymind.io>
* operator* result shape fix. Signed-off-by: raver119 <raver119@gmail.com>
* Corrected typo in lstmCell. Signed-off-by: Yurii <yurii@skymind.io>
* A few tests fixed. Signed-off-by: raver119 <raver119@gmail.com>
* CUDA inverse broadcast bool fix. Signed-off-by: raver119 <raver119@gmail.com>
* Disabled MMAP test for CUDA. Signed-off-by: raver119 <raver119@gmail.com>
* BooleanOp syncToDevice. Signed-off-by: raver119 <raver119@gmail.com>
* meh Signed-off-by: raver119 <raver119@gmail.com>
* Additional data types for im2col/col2im. Signed-off-by: raver119 <raver119@gmail.com>
* Added test for firas_sparse op.
* One more RandomBuffer test excluded. Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for flatten op.
* Added test for Floor op.
* A bunch of tests fixed. Signed-off-by: raver119 <raver119@gmail.com>
* mmulDot tests fixed. Signed-off-by: raver119 <raver119@gmail.com>
* More tests fixed. Signed-off-by: raver119 <raver119@gmail.com>
* Implemented floordiv_bp op and tests.
* Fixed scalar case in the cuda implementation for bds.
* Work on the cuda kernel for the clip_by_norm backprop op is completed. Signed-off-by: Yurii <yurii@skymind.io>
* Eliminated cbow crash.
* More tests fixed. Signed-off-by: raver119 <raver119@gmail.com>
* More tests fixed. Signed-off-by: raver119 <raver119@gmail.com>
* Eliminated abort in the batched NLP test.
* More tests fixed. Signed-off-by: raver119 <raver119@gmail.com>
* Fixed shared flag initialization.
* Disabled a bunch of cpu workspaces tests. Signed-off-by: raver119 <raver119@gmail.com>
* Scalar operators fix: missing registerSpecialUse call. Signed-off-by: raver119 <raver119@gmail.com>
* Fixed logdet for cuda and tests.
* Corrected clipBynorm_bp. Signed-off-by: Yurii <yurii@skymind.io>
* Fixed crop_and_resize shape datatype.
* Corrected some mmul tests. Signed-off-by: Yurii <yurii@skymind.io>
This commit is contained in:
parent e565788329
commit 3c4e959e21
@@ -25,6 +25,12 @@
 #include "Environment.h"
 #include <helpers/StringUtils.h>
 
+#ifdef _OPENMP
+
+#include <omp.h>
+
+#endif
+
 #ifdef __CUDABLAS__
 
 #include <cuda.h>
@@ -77,6 +83,10 @@ namespace nd4j {
 
 cudaSetDevice(0);
 delete[] devProperties;
+#else
+#ifdef _OPENMP
+omp_set_nested(1);
+#endif
 #endif
 }
 
@@ -266,11 +266,11 @@ namespace nd4j {
  * @param writeList
  * @param readList
  */
-static void registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
+static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
-static void prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
+static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
 
-static void registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
+static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
-static void preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
+static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
 
 /**
  * This method returns buffer pointer offset by given number of elements, wrt own data type
@@ -495,29 +495,29 @@ namespace nd4j {
 /**
  * this method assigns values of given array to this one
 */
-void assign(const NDArray* other);
+void assign(const NDArray* other, bool allowParallelism = true);
 
 /**
  * this method assigns values of given array to this one
 */
-void assign(const NDArray& other);
+void assign(const NDArray& other, bool allowParallelism = true);
 
 /**
  * this method assigns given value to all elements in array
 */
-void assign(const double value);
+void assign(const double value, bool allowParallelism = true);
-void assign(const float value);
+void assign(const float value, bool allowParallelism = true);
-void assign(const float16 value);
+void assign(const float16 value, bool allowParallelism = true);
-void assign(const bfloat16& value);
+void assign(const bfloat16& value, bool allowParallelism = true);
-void assign(const Nd4jLong value);
+void assign(const Nd4jLong value, bool allowParallelism = true);
-void assign(const int value);
+void assign(const int value, bool allowParallelism = true);
-void assign(const int16_t value);
+void assign(const int16_t value, bool allowParallelism = true);
-void assign(const uint8_t value);
+void assign(const uint8_t value, bool allowParallelism = true);
-void assign(const uint16_t value);
+void assign(const uint16_t value, bool allowParallelism = true);
-void assign(const uint32_t value);
+void assign(const uint32_t value, bool allowParallelism = true);
-void assign(const uint64_t value);
+void assign(const uint64_t value, bool allowParallelism = true);
-void assign(const int8_t value);
+void assign(const int8_t value, bool allowParallelism = true);
-void assign(const bool value);
+void assign(const bool value, bool allowParallelism = true);
 
 /**
  * returns new copy of this array, optionally in different order
@@ -24,6 +24,7 @@
 #include <ConstantShapeHelper.h>
 #include <ConstantTadHelper.h>
 #include <BroadcastPairwiseConverter.h>
+#include <helpers/PointersManager.h>
 
 namespace nd4j {
 
@@ -573,13 +574,13 @@ void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCop
 
 //////////////////////////////////////////////////////////////////////////
 // This method assigns values of given NDArray to this one, wrt order
-void NDArray::assign(const NDArray *other) {
+void NDArray::assign(const NDArray *other, bool allowParallelism) {
-    assign(*other);
+    assign(*other, allowParallelism);
 }
 
 //////////////////////////////////////////////////////////////////////////
 // This method assigns given value to all elements in this NDArray
-void NDArray::assign(const double value) {
+void NDArray::assign(const double value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
 
@@ -589,122 +590,122 @@ void NDArray::assign(const double value) {
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const float value) {
+void NDArray::assign(const float value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(dataType(), value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const float16 value) {
+void NDArray::assign(const float16 value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const bfloat16& value) {
+void NDArray::assign(const bfloat16& value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const Nd4jLong value) {
+void NDArray::assign(const Nd4jLong value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const int value) {
+void NDArray::assign(const int value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const int16_t value) {
+void NDArray::assign(const int16_t value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, buffer(), _shapeInfo, specialBuffer(), _shapeInfoD, temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp._shapeInfoD, nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const uint8_t value) {
+void NDArray::assign(const uint8_t value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const uint16_t value) {
+void NDArray::assign(const uint16_t value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const uint32_t value) {
+void NDArray::assign(const uint32_t value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const uint64_t value) {
+void NDArray::assign(const uint64_t value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const int8_t value) {
+void NDArray::assign(const int8_t value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
 //////////////////////////////////////////////////////////////////////////
-void NDArray::assign(const bool value) {
+void NDArray::assign(const bool value, bool allowParallelism) {
 // just fire scalar
 auto temp = NDArrayFactory::create(this->dataType(), value, this->getContext());
 
 NDArray::prepareSpecialUse({this}, {&temp});
-NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr);
+NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::CopyPws, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.getSpecialShapeInfo(), nullptr, allowParallelism);
 NDArray::registerSpecialUse({this}, {&temp});
 }
 
@@ -1167,15 +1168,15 @@ static void printFormatted(NDArray const* arr, int depth, int limit) {
 Nd4jLong colLimit = cols > limit?cols:limit;
 for (Nd4jLong col = 0; col < colLimit; ++col) {
 if (col)
-printf(" ");
+printf(", ");
 if (arr->isR())
-printf("%f,", arr->e<float>(row, col));
+printf("%f", arr->e<float>(row, col));
 else if (arr->isZ())
-printf("%lld,", arr->e<Nd4jLong>(row, col));
+printf("%lld", arr->e<Nd4jLong>(row, col));
 else if (arr->isB())
-printf("%s,", arr->e<bool>(row, col)?"true":"false");
+printf("%s", arr->e<bool>(row, col)?"true":"false");
 else if (arr->isS())
-printf("\"%s\",", arr->e<std::string>(row * cols + col).c_str());
+printf("\"%s\"", arr->e<std::string>(row * cols + col).c_str());
 }
 if (row < rows - 1)
 printf("]\n");
@@ -2190,7 +2191,12 @@ void NDArray::operator+=(const T value) {
 throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!");
 
 auto other = NDArrayFactory::create(this->dataType(), value, getContext());
+
+NDArray::prepareSpecialUse({this}, {&other});
+
 NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
+
+NDArray::registerSpecialUse({this}, {});
 }
 template void NDArray::operator+=(const double value);
 template void NDArray::operator+=(const float value);
@@ -2207,7 +2213,12 @@ void NDArray::operator-=(const T value) {
         throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!");

     auto other = NDArrayFactory::create(dataType(), value, getContext());
+
+    NDArray::prepareSpecialUse({this}, {&other});
+
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
+
+    NDArray::registerSpecialUse({this}, {});
 }
 template void NDArray::operator-=(const double value);
 template void NDArray::operator-=(const float value);
@@ -2224,7 +2235,10 @@ void NDArray::operator*=(const T scalar) {
         throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!");

     auto other = NDArrayFactory::create(this->dataType(), scalar, getContext());
+    NDArray::prepareSpecialUse({this}, {&other});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
+
+    NDArray::registerSpecialUse({this}, {});
 }
 template void NDArray::operator*=(const double scalar);
 template void NDArray::operator*=(const float scalar);
@@ -2244,7 +2258,9 @@ void NDArray::operator/=(const T scalar) {
         throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!");

     auto other = NDArrayFactory::create(this->dataType(), scalar, getContext());
+    NDArray::prepareSpecialUse({this}, {&other});
     NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
+    NDArray::registerSpecialUse({this}, {});
 }
 template void NDArray::operator/=(const double scalar);
 template void NDArray::operator/=(const float scalar);
@@ -2287,11 +2303,13 @@ NDArray NDArray::operator*(const NDArray& other) const {
     if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
         throw nd4j::datatype_exception::build("NDArray operator*: Cannot multiply different types", this->dataType(), other.dataType());

+    PointersManager pointersManager(getContext(), "operator *");
+
     if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
         NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, this->getContext());

         NDArray::prepareSpecialUse({&result}, {this, &other});
-        NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
+        NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Multiply, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr);
         NDArray::registerSpecialUse({&result}, {this, &other});

         return result;
@@ -2973,7 +2991,7 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *
         throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !");

     NDArray::prepareSpecialUse({target}, {this, other});
-    NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
+    NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
     NDArray::registerSpecialUse({target}, {this, other});
 }

@@ -3050,7 +3068,7 @@ NDArray* NDArray::varianceAlongDimension(nd4j::variance::Ops op, const bool bias

 ////////////////////////////////////////////////////////////////////
 // This method assigns values of given NDArray to this one
-void NDArray::assign(const NDArray& other) {
+void NDArray::assign(const NDArray& other, bool allowParallelism) {

     if (this == &other)
         return;
@@ -3082,13 +3100,13 @@ void NDArray::assign(const NDArray& other) {
     if (dataType() != other.dataType()) {
         auto tmp = other.cast(dataType());
         NDArray::prepareSpecialUse({this}, {tmp});
-        NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr);
+        NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp->getBuffer(), tmp->getShapeInfo(), tmp->getSpecialBuffer(), tmp->getSpecialShapeInfo(), nullptr, allowParallelism);
         NDArray::registerSpecialUse({this}, {});
         delete tmp;
     }
     else {
         NDArray::prepareSpecialUse({this}, {&other});
-        NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
+        NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr, allowParallelism);
         NDArray::registerSpecialUse({this}, {&other});
     }
 }
@@ -3106,7 +3124,7 @@ void NDArray::assign(const NDArray& other) {
         copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
     else {
         NDArray::prepareSpecialUse({this}, {&other});
-        NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr);
+        NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism);
         NDArray::registerSpecialUse({this}, {&other});
     }
 }
@@ -177,7 +177,7 @@ public:
                            void *dZ, Nd4jLong *dZShapeInfo,
                            void *hScalar, Nd4jLong *hSscalarShapeInfo,
                            void *dScalar, Nd4jLong *dSscalarShapeInfo,
-                           void *extraParams);
+                           void *extraParams, bool allowParallelism = true);

 static void execScalarBool(nd4j::LaunchContext *lc,
                            int opNum,
@@ -187,7 +187,7 @@ static void execScalarBool(nd4j::LaunchContext *lc,
                            void *dZ, Nd4jLong *dZShapeInfo,
                            void *hScalar, Nd4jLong *hSscalarShapeInfo,
                            void *dScalar, Nd4jLong *dSscalarShapeInfo,
-                           void *extraParams);
+                           void *extraParams, bool allowParallelism = true);

 static void execScalar(nd4j::LaunchContext *lc,
                        int opNum,
@@ -334,7 +334,7 @@ static void execTransformAny(nd4j::LaunchContext *lc,
                              void *hZ, Nd4jLong *hZShapeInfo,
                              void *dZ, Nd4jLong *dZShapeInfo,
                              void *extraParams,
-                             Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+                             Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism = true);

 static void execTransformStrict(nd4j::LaunchContext *lc,
                                 int opNum,
@@ -181,16 +181,16 @@ void NDArray::swapUnsafe(NDArray& other) {
 void NDArray::synchronize(const char* msg) const {
     // no-op
 }
-void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
+void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
     // no-op
 }
-void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
+void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
     // no-op
 }
-void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
+void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
     // no-op
 }
-void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
+void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
     // no-op
 }
 void NDArray::syncShape() const {
@@ -48,6 +48,13 @@
 #include <exceptions/datatype_exception.h>

+
+#ifdef _OPENMP
+
+#include <omp.h>
+
+#endif
+


 ////////////////////////////////////////////////////////////////////////
@@ -67,6 +74,10 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, int op
         void *hZ, Nd4jLong *hZShapeInfo,
         void *dZ, Nd4jLong *dZShapeInfo) {

+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif
+
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto hz = reinterpret_cast<Nd4jLong*>(hZ);

@@ -95,6 +106,9 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc,
         void *dZ, Nd4jLong *dZShapeInfo,
         int *dimension, int dimensionLength,
         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     Nd4jLong* hz = reinterpret_cast<Nd4jLong*>(hZ);
@@ -129,6 +143,10 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
         Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
         Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {

+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif
+
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -155,6 +173,10 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif
+
     if (!nd4j::Environment::getInstance()->isExperimentalBuild())
         if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType)
             throw nd4j::datatype_exception::build("NativeOps::execBroadcast both operands must have same data type", xType, yType);
@@ -180,6 +202,9 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
         int *dimension, int dimensionLength,
         Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
         Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
@@ -199,6 +224,11 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
         int *dimension, int dimensionLength,
         Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
         Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
+
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif
+
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -232,6 +262,9 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
         void *hZ, Nd4jLong *hZShapeInfo,
         void *dZ, Nd4jLong *dZShapeInfo,
         void *extraParams) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
@@ -254,6 +287,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
         void *hZ, Nd4jLong *hZShapeInfo,
         void *dZ, Nd4jLong *dZShapeInfo,
         void *extraParams) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
@@ -282,6 +318,10 @@ void NativeOpExecutioner::execReduceFloat(nd4j::LaunchContext *lc,
         int *dimension, int dimensionLength,
         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif
+
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

@@ -298,6 +338,9 @@ void NativeOpExecutioner::execReduceSame(nd4j::LaunchContext *lc,
         void *dZ, Nd4jLong *dZShapeInfo,
         int *dimension, int dimensionLength,
         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -315,6 +358,9 @@ void NativeOpExecutioner::execReduceBool(nd4j::LaunchContext *lc,
         void *dZ, Nd4jLong *dZShapeInfo,
         int *dimension, int dimensionLength,
         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -332,6 +378,9 @@ void NativeOpExecutioner::execReduceLong(nd4j::LaunchContext *lc,
         void *dZ, Nd4jLong *dZShapeInfo,
         int *dimension, int dimensionLength,
         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -355,6 +404,9 @@ void NativeOpExecutioner::execReduceFloatScalar(nd4j::LaunchContext *lc,
         void *extraParams,
         void *hZ, Nd4jLong *hZShapeInfo,
         void *dZ, Nd4jLong *dZShapeInfo) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -370,6 +422,9 @@ void NativeOpExecutioner::execReduceSameScalar(nd4j::LaunchContext *lc,
         void *extraParams,
         void *hZ, Nd4jLong *hZShapeInfo,
         void *dZ, Nd4jLong *dZShapeInfo) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);

@@ -385,6 +440,10 @@ void NativeOpExecutioner::execReduceBoolScalar(nd4j::LaunchContext *lc,
         void *hZ, Nd4jLong *hZShapeInfo,
         void *dZ, Nd4jLong *dZShapeInfo) {

+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif
+
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

@@ -399,6 +458,9 @@ void NativeOpExecutioner::execReduceLongScalar(nd4j::LaunchContext *lc,
         void *extraParams,
         void *hZ, Nd4jLong *hZShapeInfo,
         void *dZ, Nd4jLong *dZShapeInfo) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -430,6 +492,9 @@ void NativeOpExecutioner::execReduce3Scalar(nd4j::LaunchContext *lc,
         void *dY, Nd4jLong *dYShapeInfo,
         void *hZ, Nd4jLong *hZShapeInfo,
         void *dZ, Nd4jLong *dZShapeInfo) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -459,6 +524,9 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc,
         void *dY, Nd4jLong *dYShapeInfo,
         void *hZ, Nd4jLong *hZShapeInfo,
         void *dZ, Nd4jLong *dZShapeInfo) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -480,6 +548,9 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc,
         int *dimension, int dimensionLength,
         Nd4jLong *xTadOnlyShapeInfo, Nd4jLong *xTadOffsets,
         Nd4jLong *yTadOnlyShapeInfo, Nd4jLong *yTadOffsets) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -501,6 +572,9 @@ void NativeOpExecutioner::execReduce3All(nd4j::LaunchContext *lc,
         int *dimension, int dimensionLength,
         Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
         Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
@@ -523,6 +597,10 @@ void NativeOpExecutioner::execReduce3TAD(nd4j::LaunchContext *lc,
         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
         Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets) {

+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif
+
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

@@ -551,7 +629,10 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
         void *dZ, Nd4jLong *dZShapeInfo,
         void *hScalar, Nd4jLong *hScalarShapeInfo,
         void *dScalar, Nd4jLong *dScalarShapeInfo,
-        void *extraParams) {
+        void *extraParams, bool allowParallelism) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
@@ -563,7 +644,7 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
     if (xType != yType || xType != zType)
         throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalar", zType, xType, yType);

-    BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES);
+    BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams, allowParallelism), LIBND4J_TYPES);
 #endif
 }

@@ -580,6 +661,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
         int *dimension, int dimensionLength,
         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
         Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif

     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
@@ -604,7 +688,11 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
         void *dZ, Nd4jLong *dZShapeInfo,
         void *hScalar, Nd4jLong *hSscalarShapeInfo,
         void *dScalar, Nd4jLong *dSscalarShapeInfo,
-        void *extraParams) {
+        void *extraParams, bool allowParallelism) {

+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif
+
     auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
     auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
@@ -632,6 +720,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
         int *dimension, int dimensionLength,
         Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
         Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
+#ifdef _OPENMP
+    omp_set_nested(1);
+#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
@ -664,6 +755,9 @@ void NativeOpExecutioner::execSummaryStats(nd4j::LaunchContext *lc,
|
|||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
bool biasCorrected) {
|
bool biasCorrected) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
@ -689,6 +783,9 @@ void NativeOpExecutioner::execSummaryStatsScalar(nd4j::LaunchContext *lc,
|
|||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
bool biasCorrected) {
|
bool biasCorrected) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
@ -718,6 +815,9 @@ void NativeOpExecutioner::execSummaryStats(nd4j::LaunchContext *lc,
|
|||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
bool biasCorrected) {
|
bool biasCorrected) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
@ -745,6 +845,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
|
|||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
@ -761,6 +864,9 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
|
|||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
@ -776,12 +882,15 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
|
|||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, tadShapeInfo, tadOffsets, allowParallelism), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
@ -793,6 +902,9 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,
|
|||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
@ -809,6 +921,9 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
|
|||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
@ -823,6 +938,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
|
|||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraArguments) {
|
void *extraArguments) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
@ -841,6 +959,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
|
|||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraArguments) {
|
void *extraArguments) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
@ -861,6 +982,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
|
|||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraArguments) {
|
void *extraArguments) {
|
||||||
|
#ifdef _OPENMP
|
||||||
|
omp_set_nested(1);
|
||||||
|
#endif
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
@@ -1895,33 +1895,31 @@ void execRandom2(Nd4jPointer *extraPointers,
 }

 Nd4jPointer initRandom(Nd4jPointer *extraPointers, long seed, long bufferSize, Nd4jPointer ptrToBuffer) {
-    auto ptrBuf = reinterpret_cast<long *>(ptrToBuffer);
-    auto buffer = new nd4j::random::RandomBuffer(seed, bufferSize, reinterpret_cast<uint64_t *>(ptrBuf));
-
-    nd4j::random::Xoroshiro128 generator(buffer);
-    generator.refreshBuffer();
-
-    return (Nd4jPointer) buffer;
+    graph::RandomGenerator* generator = new graph::RandomGenerator(seed, seed);
+    // auto ptrBuf = reinterpret_cast<long *>(ptrToBuffer);
+    // auto buffer = new nd4j::random::RandomBuffer(seed, bufferSize, reinterpret_cast<uint64_t *>(ptrBuf));
+    //
+    // nd4j::random::Xoroshiro128 generator(buffer);
+    // generator.refreshBuffer();
+    //
+    return (Nd4jPointer) generator;
 }

 void refreshBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) {
-    auto buffer = reinterpret_cast<nd4j::random::RandomBuffer *> (ptrRandom);
+    auto generator = reinterpret_cast<nd4j::graph::RandomGenerator*> (ptrRandom);

-    buffer->setSeed(seed);
-    buffer->setOffset(0);
-
-    nd4j::random::Xoroshiro128 generator(buffer);
-    generator.refreshBuffer();
+    generator->setStates(seed);
 }

 void reSeedBuffer(Nd4jPointer *extraPointers, long seed, Nd4jPointer ptrRandom) {
-    auto buffer = reinterpret_cast<nd4j::random::RandomBuffer *> (ptrRandom);
+    auto generator = reinterpret_cast<nd4j::graph::RandomGenerator *> (ptrRandom);

-    buffer->reSeed(seed);
+    generator->setStates(seed);
 }


 void destroyRandom(Nd4jPointer ptrBuffer) {
-    auto buffer = reinterpret_cast<nd4j::random::RandomBuffer *>(ptrBuffer);
+    auto buffer = reinterpret_cast<nd4j::graph::RandomGenerator*>(ptrBuffer);
     delete buffer;
 }

@@ -231,7 +231,7 @@ void NDArray::synchronize(const char* msg) const {
         throw std::runtime_error(msg + std::string(": synchronization failed !"));
 }
 ////////////////////////////////////////////////////////////////////////
-void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
+void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {

     for (const auto& a : readList)
         if(a != nullptr)
@@ -247,7 +247,7 @@ void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& wri
 }

 ////////////////////////////////////////////////////////////////////////
-void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
+void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {

     for (const auto& p : readList)
         if(p != nullptr)
@@ -259,7 +259,7 @@ void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& wr
 }

 ////////////////////////////////////////////////////////////////////////
-void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
+void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {

     for (const auto& a : readList)
         if(a != nullptr)
@@ -275,7 +275,7 @@ void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& wri
 }

 ////////////////////////////////////////////////////////////////////////
-void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
+void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {

     for (const auto& p : readList)
         if(p != nullptr)
|
@ -20,6 +20,7 @@
|
|||||||
#include <helpers/DebugHelper.h>
|
#include <helpers/DebugHelper.h>
|
||||||
#include <DataTypeUtils.h>
|
#include <DataTypeUtils.h>
|
||||||
#include <exceptions/datatype_exception.h>
|
#include <exceptions/datatype_exception.h>
|
||||||
|
#include <exceptions/cuda_exception.h>
|
||||||
#include <helpers/CudaLaunchHelper.h>
|
#include <helpers/CudaLaunchHelper.h>
|
||||||
#include <helpers/ShapeBuilders.h>
|
#include <helpers/ShapeBuilders.h>
|
||||||
#include <PointersManager.h>
|
#include <PointersManager.h>
|
||||||
@ -112,7 +113,10 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
|
|||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES)
|
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::pairwise_transforms::PairWiseTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
DEBUG_KERNEL(stream, opNum);
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execPairwiseTransform failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
@ -141,6 +145,11 @@ void NativeOpExecutioner::execPairwiseBoolTransform( nd4j::LaunchContext *lc,
|
|||||||
dim3 launchDims(256, 1024, 16384);
|
dim3 launchDims(256, 1024, 16384);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES)
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES)
|
||||||
|
|
||||||
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execPairwiseBoolTransform failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
@ -162,6 +171,11 @@ void NativeOpExecutioner::execSummaryStatsScalar(nd4j::LaunchContext *lc,
|
|||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, nullptr, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, nullptr, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execSummaryStatsScalar failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
@ -196,7 +210,10 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
|
|||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
|
||||||
|
|
||||||
DEBUG_KERNEL(stream, opNum);
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execBroadcastBool failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
@ -229,7 +246,10 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
|||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
|
||||||
|
|
||||||
DEBUG_KERNEL(stream, opNum);
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execInverseBroadcastBool failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
@ -274,7 +294,10 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
|
|||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
DEBUG_KERNEL(stream, opNum);
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execBroadcast failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
|
void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
|
||||||
@ -306,7 +329,10 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
|
|||||||
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
DEBUG_KERNEL(stream, opNum);
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execInverseBroadcast failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
@ -336,9 +362,12 @@ void NativeOpExecutioner::execReduceSame(nd4j::LaunchContext *lc,
|
|||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks, 256, 8192);
|
dim3 launchDims(numBlocks, 256, 8192);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
|
||||||
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "execReduceSame(...) failed");
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execReduceSame failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
@ -368,9 +397,12 @@ void NativeOpExecutioner::execReduceLong(nd4j::LaunchContext *lc,
|
|||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks, 256, 32768);
|
dim3 launchDims(numBlocks, 256, 32768);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LONG_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, LONG_TYPES);
|
||||||
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed");
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execReduceLong failed", res);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -401,9 +433,12 @@ void NativeOpExecutioner::execReduceBool(nd4j::LaunchContext *lc,
|
|||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks, 256, 32768);
|
dim3 launchDims(numBlocks, 256, 32768);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed");
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execReduceBool failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
@ -446,6 +481,11 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc,
|
|||||||
auto dz = reinterpret_cast<Nd4jLong*>(dZ);
|
auto dz = reinterpret_cast<Nd4jLong*>(dZ);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES);
|
||||||
|
|
||||||
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execIndexReduce failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
@ -481,7 +521,12 @@ void NativeOpExecutioner::execReduceFloat(nd4j::LaunchContext *lc,
|
|||||||
auto numBlocks = shape::length(hZShapeInfo);
|
auto numBlocks = shape::length(hZShapeInfo);
|
||||||
dim3 launchDims(numBlocks, 256, 32768);
|
dim3 launchDims(numBlocks, 256, 32768);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX,dXShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceXD(launchDims, stream, opNum, xRank, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execReduceFloat failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -536,7 +581,10 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc,
|
|||||||
1,
|
1,
|
||||||
allocationPointer, reductionPointer,
|
allocationPointer, reductionPointer,
|
||||||
nullptr, nullptr), LIBND4J_TYPES);
|
nullptr, nullptr), LIBND4J_TYPES);
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "execIndexReduceScalar(...) failed");
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execIndexReduceScalar failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -560,7 +608,12 @@ void NativeOpExecutioner::execReduceFloatScalar(nd4j::LaunchContext *lc,
|
|||||||
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
|
||||||
dim3 launchDims(numBlocks, blockWidth, 32768);
|
dim3 launchDims(numBlocks, blockWidth, 32768);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceScalar(launchDims, stream, opNum, dX,dXShapeInfo, extraParams, dZ,dZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceFloatFunction, ::execReduceScalar(launchDims, stream, opNum, dX,dXShapeInfo, hXShapeInfo, extraParams, dZ,dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
|
// TODO: remove after the release
|
||||||
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw cuda_exception::build("execReduceFloatScalar failed", res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -587,7 +640,12 @@ void NativeOpExecutioner::execReduceBoolScalar(nd4j::LaunchContext *lc,
     auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
     dim3 launchDims(numBlocks, blockWidth, 32768);

-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, extraParams, dZ, dZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, BOOL_TYPES);
+    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, BOOL_TYPES);
+
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execReduceBoolScalar failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -613,7 +671,12 @@ void NativeOpExecutioner::execReduceSameScalar(nd4j::LaunchContext *lc,
     auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
     dim3 launchDims(numBlocks, blockWidth, 32768);

-    BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, extraParams, dZ, dZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES);
+    BUILD_SINGLE_SELECTOR(xType, functions::reduce::ReduceSameFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES);
+
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execReduceSameScalar failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -639,7 +702,12 @@ void NativeOpExecutioner::execReduceLongScalar(nd4j::LaunchContext *lc,
     auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
     dim3 launchDims(numBlocks, blockWidth, 32768);

-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, extraParams, dZ, dZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, LONG_TYPES);
+    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, nullptr), LIBND4J_TYPES, LONG_TYPES);
+
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execReduceLongScalar failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
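All four reduce-scalar hunks size their launch the same way: `getReductionBlocks(xLength, blockWidth)` for the grid, `blockWidth` threads per block, 32768 bytes of shared memory. `CudaLaunchHelper::getReductionBlocks` itself is not shown in this diff; a plausible sketch — an assumption, not the actual implementation — is a ceiling division of the element count by the block width, capped so the grid stays bounded:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical reconstruction of CudaLaunchHelper::getReductionBlocks:
// enough blocks so that every element is covered by one thread, capped so a
// single launch never exceeds a fixed grid size (each block then reduces a
// strided slice and the partials are combined via reductionPointer).
int64_t getReductionBlocksSketch(int64_t xLength, int64_t blockWidth, int64_t maxBlocks = 512) {
    int64_t blocks = (xLength + blockWidth - 1) / blockWidth;   // ceil(xLength / blockWidth)
    return blocks < maxBlocks ? blocks : maxBlocks;             // cap the grid size
}
```

With the capped grid, `dim3 launchDims(numBlocks, blockWidth, 32768)` stays valid for arbitrarily large inputs.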
@@ -665,7 +733,10 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,

     BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES);

-    nd4j::DebugHelper::checkErrorCode(stream, "execTransformSame(...) failed");
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execTransformSame failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -690,6 +761,11 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
         throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type");

     BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES);

+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execTransformBool failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -700,7 +776,7 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
                         void *hZ, Nd4jLong *hZShapeInfo,
                         void *dZ, Nd4jLong *dZShapeInfo,
                         void *extraParams,
-                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
+                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism) {

     auto stream = lc->getCudaStream();

@@ -751,6 +827,11 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
         BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
         }
     }

+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execTransformAny failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -775,6 +856,11 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
         throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType);

     BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES);

+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execTransformStrict failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -800,6 +886,11 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,

     dim3 launchDims(512, 512, 16384);
     BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);

+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execTransformFloat failed", res);
 }

@@ -825,6 +916,11 @@ void NativeOpExecutioner::execSummaryStats(nd4j::LaunchContext *lc,
         throw nd4j::datatype_exception::build("NativeOpExecutioner::execSummaryStats requires Z operand to have floating point data type", zType);

     BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduce(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, nullptr, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES);

+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execSummaryStats A failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -850,6 +946,11 @@ void NativeOpExecutioner::execSummaryStats(nd4j::LaunchContext *lc,
         throw nd4j::datatype_exception::build("NativeOpExecutioner::execSummaryStats requires Z operand to have floating point data type", zType);

     BUILD_DOUBLE_SELECTOR(xType, zType, functions::summarystats::SummaryStatsReduce, ::execSummaryStatsReduce(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets, biasCorrected, reductionPointer), LIBND4J_TYPES, FLOAT_TYPES);

+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execSummaryStats B failed", res);
 }

@@ -884,7 +985,10 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc,

     BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, allocationPointer, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES);

-    DEBUG_KERNEL(stream, opNum);
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execReduce3 failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -933,6 +1037,11 @@ void NativeOpExecutioner::execReduce3(nd4j::LaunchContext *lc,
                                             allocationPointer,
                                             tadOnlyShapeInfo, tadOffsets,
                                             yTadOnlyShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES);

+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execReduce3 B failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -967,6 +1076,11 @@ void NativeOpExecutioner::execReduce3Scalar(nd4j::LaunchContext *lc,
         throw nd4j::datatype_exception::build("NativeOpExecutioner::execReduce3Scalar requires Z operand to have floating point data type", zType);

     BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execScalar(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, allocationPointer, reductionPointer, nullptr), LIBND4J_TYPES, FLOAT_TYPES);

+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execReduce3Scalar failed", res);
 }

@@ -979,7 +1093,7 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
                         void *dZ, Nd4jLong *dZShapeInfo,
                         void *hScalar, Nd4jLong *hScalarShapeInfo,
                         void *dScalar, Nd4jLong *dScalarShapeInfo,
-                        void *extraParams) {
+                        void *extraParams, bool allowParallelism) {

     auto stream = lc->getCudaStream();

@@ -997,7 +1111,10 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,

     BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, BOOL_TYPES);

-    DEBUG_KERNEL(stream, opNum);
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execScalarBool failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -1030,7 +1147,10 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,

     BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);

-    DEBUG_KERNEL(stream, opNum);
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execScalarBool B failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -1042,7 +1162,7 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
                         void *dZ, Nd4jLong *dZShapeInfo,
                         void *hScalar, Nd4jLong *hScalarShapeInfo,
                         void *dScalar, Nd4jLong *dScalarShapeInfo,
-                        void *extraParams) {
+                        void *extraParams, bool allowParallelism) {

     auto stream = lc->getCudaStream();

@@ -1059,7 +1179,10 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
     BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES);
 #endif

-    DEBUG_KERNEL(stream, opNum);
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execScalar failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -1090,7 +1213,10 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
     BUILD_SINGLE_SELECTOR_THRICE(xType, functions::scalar::ScalarTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES);
 #endif

-    DEBUG_KERNEL(stream, opNum);
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execScalar B failed", res);
 }

 ////////////////////////////////////////////////////////////////////////
@@ -1117,7 +1243,10 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
     // functions::random::RandomFunction<float>::executeCudaSingle(launchDims, extraPointers, opNum, stateHost, dZ, dZShapeInfo, extraArguments),
     BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::executeCudaSingle(launchDims, stream, opNum, stateDevice, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES);

-    checkCudaErrors(cudaStreamSynchronize(*stream));
+    res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execRandom X failed", res);

     cudaFree(stateDevice);

     rng->rewindH(shape::length(hZShapeInfo));
@@ -1149,7 +1278,10 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
     // functions::random::RandomFunction<float>::executeCudaDouble(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments);
     BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaDouble(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES);

-    checkCudaErrors(cudaStreamSynchronize(*stream));
+    res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execRandom XY failed", res);

     cudaFree(stateDevice);

     rng->rewindH(shape::length(hZShapeInfo));
@@ -1182,7 +1314,10 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
     // functions::random::RandomFunction<float>::executeCudaTriple(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments);
     BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaTriple(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES);

-    checkCudaErrors(cudaStreamSynchronize(*stream));
+    res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execRandom XYZ failed", res);

     cudaFree(stateDevice);

     rng->rewindH(shape::length(hZShapeInfo));
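A side effect of these three execRandom hunks: the new throw fires before the `cudaFree(stateDevice)` that follows it, so the RNG state buffer leaks whenever the synchronize reports an error. A hedged sketch of an RAII guard that would free the buffer on every exit path — `malloc`/`free` stand in for `cudaMalloc`/`cudaFree` so this runs without a GPU, and the guard type is illustrative, not libnd4j code:

```cpp
#include <cassert>
#include <cstdlib>
#include <stdexcept>

// RAII guard for a device-style buffer: the destructor runs even while an
// exception unwinds the stack, so the buffer is released on both the success
// path and the failure path (unlike a bare cudaFree after the throw).
struct DeviceBufferGuard {
    void* ptr;
    bool* freedFlag;                      // optional hook to observe the free
    explicit DeviceBufferGuard(std::size_t bytes, bool* flag = nullptr)
        : ptr(std::malloc(bytes)), freedFlag(flag) {}  // stands in for cudaMalloc
    ~DeviceBufferGuard() {
        std::free(ptr);                   // stands in for cudaFree(stateDevice)
        if (freedFlag) *freedFlag = true;
    }
    DeviceBufferGuard(const DeviceBufferGuard&) = delete;
    DeviceBufferGuard& operator=(const DeviceBufferGuard&) = delete;
};
```

With such a guard wrapping `stateDevice`, the sync-and-throw block could stay as written and the cleanup would still happen.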
@@ -1223,7 +1358,10 @@ void NativeOpExecutioner::execReduce3All(nd4j::LaunchContext *lc,

     BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::execAll(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParamsVals, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, xTadShapeInfo, xOffsets, yTadShapeInfo, yOffsets), LIBND4J_TYPES, FLOAT_TYPES);

-    DEBUG_KERNEL(stream, opNum);
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execReduce3All failed", res);
 }

@@ -1263,5 +1401,10 @@ void NativeOpExecutioner::execReduce3TAD(nd4j::LaunchContext *lc,
     dim3 launchDims(numBlocks, 256, 32768);

     BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce3::Reduce3, ::exec(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, extraParams, dZ, dZShapeInfo, dimension, dimensionLength, 1, allocationPointer, tadShapeInfo, tadOffsets, yTadShapeInfo, yTadOffsets), LIBND4J_TYPES, FLOAT_TYPES);

+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execReduce3TAD failed", res);
 }

@@ -473,7 +473,7 @@ void execReduceLong(Nd4jPointer *extraPointers,
     auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
     dim3 launchDims(numBlocks, blockWidth, 32768);

-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, extraParams, dZ, dZShapeInfo, nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES);
+    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceLongFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hXShapeInfo, nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, LONG_TYPES);

     nd4j::DebugHelper::checkErrorCode(stream, "execReduceLong(...) failed");
 }

@@ -526,7 +526,7 @@ void execReduceBool(Nd4jPointer *extraPointers,
     auto numBlocks = CudaLaunchHelper::getReductionBlocks(xLength, blockWidth);
     dim3 launchDims(numBlocks, blockWidth, 32768);

-    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, extraParams, dZ, dZShapeInfo, nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES);
+    BUILD_DOUBLE_SELECTOR(xType, zType, functions::reduce::ReduceBoolFunction, ::execReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, extraParams, dZ, dZShapeInfo, hZShapeInfo, nullptr, 0, reductionPointer, dTADShapeInfo), LIBND4J_TYPES, BOOL_TYPES);

     nd4j::DebugHelper::checkErrorCode(stream, "execReduceBool(...) failed");
 }

@@ -649,7 +649,8 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum,
                       void *extraParams) {

     auto stream = reinterpret_cast<cudaStream_t*>(extraPointers[1]);
-    LaunchContext lc(stream, extraPointers[4], extraPointers[5], extraPointers[3]);
+    auto streamSpecial = reinterpret_cast<cudaStream_t&>(extraPointers[4]);
+    LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], reinterpret_cast<int*>(extraPointers[6]));

     // FIXME: remove this once all operations are enabled
     if (opNum == nd4j::transform::IsMax && extraParams != nullptr) {
@@ -1309,6 +1310,7 @@ void concat(
         Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers) {

     auto stream = reinterpret_cast<cudaStream_t *>(extraPointers[1]);

     auto hXShapeInfo = hZShapeInfo;
     auto hShapePointers = reinterpret_cast<Nd4jLong **>(inputShapeInfo);
     auto dShapePointers = reinterpret_cast<Nd4jLong **>(dinputShapeInfo);
@@ -1323,8 +1325,7 @@ void concat(
     // take into account indices for first array
     auto axisSize = shape::sizeAt(reinterpret_cast<Nd4jLong*>(inputShapeInfo[0]), axis);
     indices[0][2 * axis + 1] = axisSize;
-    //printf("The axe size is %lld\n", axisSize);
-    // loop through the rest of input arrays
+
     for(int i = 1; i < numArrays; ++i) {
         indices[i][2 * axis]     = indices[i-1][2 * axis + 1];                                                     // index start from
         indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + shape::sizeAt(reinterpret_cast<Nd4jLong*>(inputShapeInfo[i]), axis);      // index end with (excluding)
@@ -1336,25 +1337,12 @@ void concat(
         specialBufferAndShapeWithOffset(dZ, hZShapeInfo, dZShapeInfo, indices[i], outSubArrsBuffs[i], outSubArrsShapes[i]);
     }

-    // prepare arrays of pointers on buffers and shapes
-    std::vector<void*> hOutBuffers(numArrays), hInBuffers(numArrays);
-    std::vector<Nd4jLong*> hOutShapeInfo(numArrays), hInShapeInfo(numArrays);
-    for(int i = 0; i < numArrays; ++i) {
-        hOutBuffers[i]   = outSubArrsBuffs[i];
-        hInBuffers[i]    = ddata[i];//->getSpecialBuffer();
-        hOutShapeInfo[i] = outSubArrsShapes[i];
-        hInShapeInfo[i]  = (Nd4jLong*)(dShapePointers[i]);//->getSpecialShapeInfo();
-        // nd4j_printf("X_%i shape ptr: %p; data ptr: %p;\n", i, hInShapeInfo[i], hInBuffers[i]);
-    }
-
-    // nd4j_printf(" done\n", "");
     LaunchContext context(stream);
-    // allocate and copy all buffers and shapes arrays to global memory
     PointersManager manager(&context, "concat");
-    void* dOutBuffers   = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
-    void* dInBuffers    = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
-    void* dInShapeInfo  = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*));
-    void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*));
+    void* dOutBuffers   = manager.replicatePointer(outSubArrsBuffs.data(), outSubArrsBuffs.size() * sizeof(void*));
+    void* dInBuffers    = manager.replicatePointer(ddata, numArrays * sizeof(void*));
+    void* dInShapeInfo  = manager.replicatePointer(dShapePointers, numArrays * sizeof(Nd4jLong*));
+    void* dOutShapeInfo = manager.replicatePointer(outSubArrsShapes.data(), outSubArrsShapes.size() * sizeof(Nd4jLong*));

     BUILD_SINGLE_SELECTOR(zType, concatCudaLauncher, (numArrays, stream, dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES);
     manager.synchronize();
@@ -1791,7 +1779,7 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
     // NativeOpExecutioner::execReduce3TAD(nullptr, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets);
     // }
 
-    nd4j_printf("Starting...\n","");
+    // nd4j_printf("Starting...\n","");
 
     auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(hXShapeInfo, reinterpret_cast<int*>(hDimension), shape::length(hDimensionShape));
     auto tadLength = shape::length(tadPack.primaryShapeInfo());
@@ -1801,7 +1789,7 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
     LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
 
     if (tadLength == yLength || tadLength == xLength) {
-        nd4j_printf("== way\n","");
+        // nd4j_printf("== way\n","");
         NativeOpExecutioner::execReduce3(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hY, hYShapeInfo, dY,
                                          dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength,
                                          tadOnlyShapeInfo, tadOffsets, yTadOnlyShapeInfo, yTadOffsets);
@@ -2694,7 +2682,7 @@ static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer*
         }
     }
 
-    if (canNullify)
+    if (canNullify && buffer != nullptr)
         memset((uint8_t *) buffer, '\0', shape::length(shape) * DataTypeUtils::sizeOfElement(ArrayOptions::dataType(shape)));
 
     auto array = new nd4j::NDArray(buffer, bufferD, shape);
@@ -94,9 +94,8 @@ static __global__ void usualCudaGemv(const bool transA, const int M, const int N
 
     T3 val = 0;
     if (row < M)
-        for (int i = 0; i < N; i++) {
+        for (int i = 0; i < N; i++)
             val = val + A[row * strideArow + i * strideAcol] * X[i * incx];
-        }
 
     Y[row * incy] = alphaZ * val + betaZ * Y[row * incy];
 }
@@ -230,6 +229,7 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
         status = cublasSgemm(*handle, transAblas, transBblas, M, N, K, &alphaF, (float*)pA->getSpecialBuffer(), lda, (float*)pB->getSpecialBuffer(), ldb, &betaF, (float*)pC->getSpecialBuffer(), ldc);
     }
     else if(ABC && aType == DataType::HALF) {
+        printf("!!!!!!!!\n");
         float16 alphaH(alpha), betaH(beta);
         status = cublasHgemm(*handle, transAblas, transBblas, M, N, K, &alphaH.data, (__half*)pA->getSpecialBuffer(), lda, (__half*)pB->getSpecialBuffer(), ldb, &betaH.data, (__half*)pC->getSpecialBuffer(), ldc);
     }
@@ -250,8 +250,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
         blocksPerGrid.y = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.y); // rows
     }
 
-    //BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+    BUILD_TRIPLE_SELECTOR(aType, bType, cType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
-    BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), LIBND4J_TYPES)
+    // BUILD_SINGLE_SELECTOR_THRICE(aType, usualGemm, (blocksPerGrid, threadsPerBlock, stream, transA, transB, M, N, K, alpha, pA->getSpecialBuffer(), lda, pB->getSpecialBuffer(), ldb, beta, pC->getSpecialBuffer(), ldc), NUMERIC_TYPES)
 }
 
 if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxM cuda failed !", status);
@@ -339,8 +339,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
         threadsPerBlock.x = 512;
         blocksPerGrid.x = math::nd4j_ceil<double, int>(static_cast<double>(M) / threadsPerBlock.x); // rows
     }
-    //BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+    BUILD_TRIPLE_SELECTOR(aType, xType, yType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
-    BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), LIBND4J_TYPES)
+    // BUILD_SINGLE_SELECTOR_THRICE(xType, usualGemv, (blocksPerGrid, threadsPerBlock, stream, transA, M, N, alpha, pA->getSpecialBuffer(), lda, X->getSpecialBuffer(), incx, beta, Y->getSpecialBuffer(), incy), NUMERIC_TYPES)
 }
 
 if (status != CUBLAS_STATUS_SUCCESS) throw cuda_exception::build("MmulHelper::mmulMxV cuda failed !", status);
@@ -397,8 +397,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
 
     NDArray::prepareSpecialUse({Z}, {X, Y});
 
-    //BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+    BUILD_TRIPLE_SELECTOR(xType, yType, zType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
-    BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), LIBND4J_TYPES)
+    // BUILD_SINGLE_SELECTOR_THRICE(xType, usualDot, (blocksPerGrid, threadsPerBlock, stream, length, alpha, X->getSpecialBuffer(), incx, Y->getSpecialBuffer(), incy, beta, Z->getSpecialBuffer()), NUMERIC_TYPES)
 
     auto cudaResult = cudaStreamSynchronize(*stream);
     if (cudaResult != 0) throw cuda_exception::build("MmulHelper::dot cuda failed !", cudaResult);
@@ -408,8 +408,8 @@ NDArray* MmulHelper::dot(const NDArray* X, const NDArray* Y, nd4j::NDArray* Z, c
     return Z;
 }
 
-//BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+BUILD_TRIPLE_TEMPLATE(template void usualGemm, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const bool transB, const int M, const int N, const int K, const double alpha, const void* vA, const int lda, const void* vB, const int ldb, const double beta, void* vC, const int ldc), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
-//BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+BUILD_TRIPLE_TEMPLATE(template void usualGemv, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const bool transA, const int M, const int N, const double alpha, const void* vA, const int lda, const void* vB, const int incx, const double beta, void* vC, const int incy), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
-//BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), LIBND4J_TYPES, FLOAT_TYPES, FLOAT_TYPES);
+BUILD_TRIPLE_TEMPLATE(template void usualDot, (const dim3 &blocksPerGrid, const dim3 &threadsPerBlock, cudaStream_t *stream, const Nd4jLong length, const double alpha, const void* vX, const Nd4jLong incx, const void* vY, const Nd4jLong incy, const double beta, void* vZ), NUMERIC_TYPES, NUMERIC_TYPES, FLOAT_TYPES);
 
 }
@@ -61,7 +61,7 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons
     // fill input gradient arrays in accordance to kind of loss function
     fillGradArrays(loss, std::vector<NDArray*>(&inArrsBP[numInArrsFF], &inArrsBP[numInArrsFF + numInGradArrsBP]));
 
-    // beck prop pass
+    // back prop pass
     ResultSet* outArrsBP = opBP.execute(argsHolderBP); // number of output arrays in back prop = numInArrsFF;
 
     NDArray tmpScalar(nd4j::DataType::DOUBLE, inArrsFF[0]->getContext()); // scalar = 0
@@ -110,9 +110,9 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
                                          void *z, Nd4jLong zStride,
                                          void *scalar,
                                          void *extraParams,
-                                         const Nd4jLong n) {
+                                         const Nd4jLong n, bool allowParallelism) {
 
-    DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n), SCALAR_OPS);
+    DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xStride, z, zStride, scalar, extraParams, n, allowParallelism), SCALAR_OPS);
 }
 
 ////////////////////////////////////////////////////////////////////////
@@ -121,9 +121,9 @@ void ScalarTransform<X, Y, Z>::transform(const int opNum,
                                          void *x, Nd4jLong *xShapeInfo,
                                          void *z, Nd4jLong *zShapeInfo,
                                          void *scalar,
-                                         void *extraParams) {
+                                         void *extraParams, bool allowParallelism) {
 
-    DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams), SCALAR_OPS);
+    DISPATCH_BY_OPNUM_TTT(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams, allowParallelism), SCALAR_OPS);
 }
 
 ////////////////////////////////////////////////////////////////////////
@@ -132,7 +132,7 @@ template<typename OpType>
 void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
                                          void *vz, Nd4jLong *zShapeInfo,
                                          void *vscalar,
-                                         void *vextraParams) {
+                                         void *vextraParams, bool allowParallelism) {
 
     auto x = reinterpret_cast<X *>(vx);
     auto z = reinterpret_cast<Z *>(vz);
@@ -146,18 +146,18 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
     nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);
 
     if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
-        transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len);
+        transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len, allowParallelism);
     }
     else {
 
        uint xShapeInfoCast[MAX_RANK];
        const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
 
-       nd4j::OmpLaunchHelper info(len);
+       nd4j::OmpLaunchHelper info(len, allowParallelism ? -1 : 1);
 
        if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
 
-           PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
+           PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
@@ -175,7 +175,7 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
            uint zShapeInfoCast[MAX_RANK];
            const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
 
-           PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
+           PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
@@ -199,18 +199,18 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong xEws,
                                          void *vz, Nd4jLong zEws,
                                          void *vscalar,
                                          void *vextraParams,
-                                         const Nd4jLong len) {
+                                         const Nd4jLong len, bool allowParallelism) {
 
     auto x = reinterpret_cast<X *>(vx);
     auto z = reinterpret_cast<Z *>(vz);
     auto scalar = reinterpret_cast<Y *>(vscalar)[0];
     auto extraParams = reinterpret_cast<Z *>(vextraParams);
 
-    nd4j::OmpLaunchHelper info(len);
+    nd4j::OmpLaunchHelper info(len, allowParallelism ? -1 : 1);
 
     if (xEws == 1 && zEws == 1) {
 
-        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
+        PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
@@ -225,7 +225,7 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong xEws,
     }
     else {
 
-        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
+        PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
@@ -38,8 +38,8 @@ namespace functions {
                                           Nd4jLong *zShapeInfo,
                                           void *extraParams,
                                           Nd4jLong *tadShapeInfo,
-                                          Nd4jLong *tadOffsets) {
+                                          Nd4jLong *tadOffsets, bool allowParallelism) {
-        DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets), TRANSFORM_ANY_OPS);
+        DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, z, zShapeInfo, extraParams, tadShapeInfo, tadOffsets, allowParallelism), TRANSFORM_ANY_OPS);
     }
 
 /////////////////////////////////////////////////////////////////////
@@ -48,7 +48,7 @@ template<typename OpType>
 void _CUDA_H TransformAny<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
                                       void *vz,Nd4jLong *zShapeInfo,
                                       void *vextraParams,
-                                      Nd4jLong *tadShapeInfo,Nd4jLong *tadOffsets) {
+                                      Nd4jLong *tadShapeInfo,Nd4jLong *tadOffsets, bool allowParallelism) {
 
     auto x = reinterpret_cast<X *>(vx);
     auto z = reinterpret_cast<Z *>(vz);
@@ -59,7 +59,10 @@ void _CUDA_H TransformAny<X, Z>::exec(void *vx, Nd4jLong *xShapeInfo,
         return;
     }
 
+    if (allowParallelism)
         nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType, true>(x, xShapeInfo, z, zShapeInfo, extraParams);
+    else
+        nd4j::TransformLoops<X,Z,X>::template loopTransform<OpType, false>(x, xShapeInfo, z, zShapeInfo, extraParams);
 }
 
 
@@ -127,7 +127,7 @@ namespace functions {
     if (threadIdx.x == 0) {
         tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength);
         tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
-        numTads = shape::length(xShapeInfo) / tadLength;
+        numTads = shape::length(yShapeInfo) / tadLength;
         xEWS = shape::elementWiseStride(xShapeInfo);
         zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
     }
@@ -98,7 +98,6 @@ void __host__ PairWiseTransform<X,Y,Z>::intermediateShaped(dim3& launchDims, cud
                                                            void *vextraParams){
 
     pairwiseSimpleShaped<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams);
-    nd4j::DebugHelper::checkErrorCode(stream, "PWT (...) failed");
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -24,6 +24,10 @@
 #include <loops/legacy_ops.h>
 #include <helpers/DebugHelper.h>
 #include <types/types.h>
+#include <execution/LaunchContext.h>
+#include <exceptions/cuda_exception.h>
+#include <loops/scalar.h>
+
 
 using namespace simdOps;
 
@@ -104,16 +108,17 @@ __device__ void ReduceBoolFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *xS
 
     //shared memory space for storing intermediate results
     __shared__ Z* sPartials;
-    __shared__ int tadLength;
-    __shared__ int numTads;
+    __shared__ int tadLength, numTads;
     __shared__ bool isPlainOutput;
 
     if (threadIdx.x == 0) {
         extern __shared__ unsigned char shmem[];
         sPartials = reinterpret_cast<Z*>(shmem);
 
+        isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1;
+
         tadLength = shape::length(tadOnlyShapeInfo); //tadLength(xShapeInfo, dimension, dimensionLength);
         numTads = shape::length(xShapeInfo) / tadLength;
-        isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1;
     }
     __syncthreads();
 
@@ -230,34 +235,65 @@ __device__ void ReduceBoolFunction<X,Z>::execScalarCuda(void *vx, Nd4jLong *xSha
 ////////////////////////////////////////////////////////////////////////
 template <typename X, typename Z>
 template<typename OpType>
-__host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShape, void *extraParams, void *z, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
+__host__ void ReduceBoolFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
 
-    simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
+    if(shape::isEmpty(hXShapeInfo)) {
+
+        if(shape::isEmpty(hZShapeInfo))
+            return;
+
+        const auto startingVal = static_cast<Z>(OpType::startingValue(reinterpret_cast<X*>(x)));
+
+        auto res = cudaMemcpyAsync(nd4j::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream);
+        if (res != 0)
+            throw nd4j::cuda_exception::build("ReduceBoolFunction<X,Z>::intermediateXD: failed to copy temporary scalar", res);
+
+        auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();
+
+        // scalar assign
+        functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr);
+    }
+    else {
+        simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
         nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolDim(...) failed");
     }
+}
 
 ////////////////////////////////////////////////////////////////////////
 template <typename X, typename Z>
 template<typename OpType>
-__host__ void ReduceBoolFunction<X,Z>::intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
+__host__ void ReduceBoolFunction<X,Z>::intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
 
+    if (shape::isEmpty(hXShapeInfo)) {
+
+        if (shape::isEmpty(hZShapeInfo))
+            return;
+
+        const auto startingVal = static_cast<Z>(OpType::startingValue(reinterpret_cast<X*>(x)));
+
+        auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream);
+        if (res != 0)
+            throw nd4j::cuda_exception::build("ReduceBoolFunction<X,Z>::intermediateScalar: failed to copy resulting scalar", res);
+    }
+    else {
         simpleScalar<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo);
         nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolScalar(...) failed");
     }
+}
 
 ////////////////////////////////////////////////////////////////////////
 template <typename X, typename Y>
-_CUDA_H void ReduceBoolFunction<X,Y>::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
+_CUDA_H void ReduceBoolFunction<X,Y>::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
 
-    DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_BOOL_OPS));
+    DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_BOOL_OPS));
     nd4j::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed");
 }
 
 ////////////////////////////////////////////////////////////////////////
 template <typename X, typename Y>
-_CUDA_H void ReduceBoolFunction<X,Y>::execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *x, Nd4jLong *xShape, void *extraParams, void *z, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
+_CUDA_H void ReduceBoolFunction<X,Y>::execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
 
-    DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_BOOL_OPS));
+    DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_BOOL_OPS));
     DEBUG_KERNEL(stream, opNum);
 }
 
|
@@ -18,13 +18,19 @@
// @author raver119@gmail.com
//

#include <execution/LaunchContext.h>
#include <exceptions/cuda_exception.h>
#include <op_boilerplate.h>
#include <loops/reduce_float.h>
#include <loops/scalar.h>
#include <loops/legacy_ops.h>
#include <helpers/DebugHelper.h>
#include <types/types.h>
#include <specials_cuda.h>

#include <cuda.h>
#include <cuda_runtime.h>

using namespace simdOps;

////////////////////////////////////////////////////////////////////////
@@ -103,26 +109,26 @@ __device__ void ReduceFloatFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *x

    //shared memory space for storing intermediate results
    __shared__ Z* sPartials;
    __shared__ int tadLength, numTads;
    __shared__ bool isPlainOutput;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char shmem[];
        sPartials = reinterpret_cast<Z*>(shmem);
        isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1;

        tadLength = shape::length(tadOnlyShapeInfo);
        numTads = shape::length(xShapeInfo) / tadLength;
    }
    __syncthreads();

    for (int r = blockIdx.x; r < numTads; r += gridDim.x) {

        auto tadOffsetForBlock = tadOffsets[r];
        sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock);

        for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
            auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
            sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams);
        }

@@ -130,7 +136,6 @@ __device__ void ReduceFloatFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *x

        // aggregate. do NOT reduce for elements > tadLength
        aggregatePartials<OpType>(sPartials, threadIdx.x, nd4j::math::nd4j_min<int>(blockDim.x, tadLength), extraParams);

        __syncthreads();

        if (threadIdx.x == 0)
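The per-TAD kernel above is a two-phase pattern: each thread first accumulates a strided partial into shared memory, then `aggregatePartials` folds the partials together. A host-side C++ sketch of the same two phases, simplified to a plain sum op (`reduceTad`, `blockDim`, and `tad` are illustrative names, not the real libnd4j API):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Host-side sketch of the per-TAD pattern above, simplified to a plain sum:
// phase 1 mirrors the strided loop (for i = tid; i < tadLength; i += blockDim);
// phase 2 stands in for aggregatePartials(), which the kernel performs as a
// shared-memory tree reduction.
float reduceTad(const std::vector<float>& tad, int blockDim) {
    std::vector<float> sPartials(blockDim, 0.0f);   // startingValue for a sum op

    // phase 1: each "thread" accumulates a strided partial
    for (int tid = 0; tid < blockDim; ++tid)
        for (std::size_t i = tid; i < tad.size(); i += blockDim)
            sPartials[tid] += tad[i];               // stand-in for OpType::update(OpType::op(...))

    // phase 2: combine partials; only min(blockDim, tadLength) lanes hold data,
    // matching the "do NOT reduce for elements > tadLength" comment above
    int active = std::min<int>(blockDim, static_cast<int>(tad.size()));
    float result = 0.0f;
    for (int tid = 0; tid < active; ++tid)
        result += sPartials[tid];
    return result;
}
```

On the device, phase 2 is a halving tree with a `__syncthreads()` between steps; the sequential fold here gives the same result for an associative op.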
@@ -229,32 +234,62 @@ __device__ void ReduceFloatFunction<X,Z>::execScalarCuda(void *vx, Nd4jLong *xSh

////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
__host__ void ReduceFloatFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShape, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShape, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    if(shape::isEmpty(hXShapeInfo)) {

        if(shape::isEmpty(hZShapeInfo))
            return;

        const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(reinterpret_cast<X*>(x)));

        auto res = cudaMemcpyAsync(nd4j::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream);
        if (res != 0)
            throw nd4j::cuda_exception::build("ReduceFloatFunction<X,Z>::intermediateXD: failed to copy temporary scalar", res);

        auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();

        // scalar assign
        functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShape, hXShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr);
    }
    else {
        simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
    }
}

////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
__host__ void ReduceFloatFunction<X,Z>::intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {

    if (shape::isEmpty(hXShapeInfo)) {

        if (shape::isEmpty(hZShapeInfo))
            return;

        const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(reinterpret_cast<X*>(x)));

        auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream);
        if (res != 0)
            throw nd4j::cuda_exception::build("ReduceFloatFunction<X,Z>::intermediateScalar: failed to copy resulting scalar", res);
    }
    else {
        simpleScalar<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo);
    }
}

////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
_CUDA_H void ReduceFloatFunction<X,Y>::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {

    DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_FLOAT_OPS));
    nd4j::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed");
}

////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
_CUDA_H void ReduceFloatFunction<X,Y>::execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *x, Nd4jLong *xShape, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShape, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShape, hXShapeInfo, extraParams, z, zShape, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_FLOAT_OPS));
    DEBUG_KERNEL(stream, opNum);
}
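`intermediateScalar` above special-cases empty inputs: instead of launching the reduction kernel, it copies the op's starting value straight into the output, with Mean mapped to NaN via `nanOrZero<Z>()`. A hedged host-side sketch of that branch (`isMean` stands in for the `std::is_same<OpType, simdOps::Mean<X,Z>>` test; all names here are illustrative, not the real libnd4j signatures):

```cpp
#include <cmath>
#include <limits>

// Sketch of the empty-input branch in intermediateScalar above: when the input
// shape is empty, no kernel is launched; the result is the op's starting value,
// with Mean special-cased to NaN.
float reduceScalarWithEmptyCheck(const float* x, long length,
                                 bool isMean, float startingValue) {
    if (length == 0)
        return isMean ? std::numeric_limits<float>::quiet_NaN() : startingValue;

    float acc = startingValue;      // normal path: run the reduction
    for (long i = 0; i < length; ++i)
        acc += x[i];                // stand-in for OpType::update/OpType::op
    return acc;
}
```

The NaN-for-Mean choice matches the convention that the mean of zero elements is undefined, while sum-like reductions can safely fall back to their identity element.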
@@ -24,6 +24,10 @@
#include <loops/legacy_ops.h>
#include <helpers/DebugHelper.h>
#include <types/types.h>
#include <execution/LaunchContext.h>
#include <exceptions/cuda_exception.h>
#include <loops/scalar.h>


using namespace simdOps;
@@ -126,16 +130,17 @@ __device__ void ReduceLongFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *xS

    //shared memory space for storing intermediate results
    __shared__ Z* sPartials;
    __shared__ int tadLength, numTads;
    __shared__ bool isPlainOutput;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char shmem[];
        sPartials = reinterpret_cast<Z*>(shmem);

        isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1;

        tadLength = shape::length(tadOnlyShapeInfo);
        numTads = shape::length(xShapeInfo) / tadLength;
    }
    __syncthreads();

@@ -145,7 +150,6 @@ __device__ void ReduceLongFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *xS

        sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock);

        for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
            auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
            sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams);
        }

@@ -153,7 +157,6 @@ __device__ void ReduceLongFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *xS

        // aggregate. do NOT reduce for elements > tadLength
        aggregatePartials<OpType>(sPartials, threadIdx.x, nd4j::math::nd4j_min<int>(blockDim.x, tadLength), extraParams);

        __syncthreads();

        if (threadIdx.x == 0)
@@ -251,32 +254,63 @@ __device__ void ReduceLongFunction<X,Z>::execScalarCuda(void *vx, Nd4jLong *xSha

////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
__host__ void ReduceLongFunction<X,Z>::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    if(shape::isEmpty(hXShapeInfo)) {

        if(shape::isEmpty(hZShapeInfo))
            return;

        const auto startingVal = static_cast<Z>(OpType::startingValue(reinterpret_cast<X*>(x)));

        auto res = cudaMemcpyAsync(nd4j::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream);
        if (res != 0)
            throw nd4j::cuda_exception::build("ReduceLongFunction<X,Z>::intermediateXD: failed to copy temporary scalar", res);

        auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();

        // scalar assign
        functions::scalar::ScalarTransform<Z, Z, Z>::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr);
    }
    else {
        simpleReduce<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
    }
}

////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
__host__ void ReduceLongFunction<X,Z>::intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {

    if (shape::isEmpty(hXShapeInfo)) {

        if (shape::isEmpty(hZShapeInfo))
            return;

        const auto startingVal = static_cast<Z>(OpType::startingValue(reinterpret_cast<X*>(x)));

        auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream);
        if (res != 0)
            throw nd4j::cuda_exception::build("ReduceLongFunction<X,Z>::intermediateScalar: failed to copy resulting scalar", res);
    }
    else {
        simpleScalar<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo);
    }
}

////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
_CUDA_H void ReduceLongFunction<X,Y>::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {

    DISPATCH_BY_OPNUM_TT(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_LONG_OPS));
    nd4j::DebugHelper::checkErrorCode(stream, "execReduceScalarLong(...) failed");
}

////////////////////////////////////////////////////////////////////////
template <typename X, typename Y>
_CUDA_H void ReduceLongFunction<X,Y>::execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *x, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    DISPATCH_BY_OPNUM_TT(intermediateXD, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(REDUCE_LONG_OPS));
    DEBUG_KERNEL(stream, opNum);
}
@@ -24,6 +24,10 @@
#include <loops/legacy_ops.h>
#include <helpers/DebugHelper.h>
#include <types/types.h>
#include <execution/LaunchContext.h>
#include <exceptions/cuda_exception.h>
#include <loops/scalar.h>


using namespace simdOps;
@@ -111,24 +115,21 @@ __device__ void ReduceSameFunction<X>::transformCudaXD( void *vx, Nd4jLong *xSha

    //shared memory space for storing intermediate results
    __shared__ X* sPartials;

    __shared__ int tadLength, tadRank, numTads;
    __shared__ Nd4jLong *tadShape, *tadStride;
    __shared__ bool isPlainOutput;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char shmem[];
        sPartials = reinterpret_cast<X*>(shmem);

        isPlainOutput = shape::order(zShapeInfo) == 'c' && shape::elementWiseStride(zShapeInfo) == 1;

        tadLength = shape::length(tadOnlyShapeInfo);
        tadRank = shape::rank(tadOnlyShapeInfo);
        numTads = shape::length(xShapeInfo) / tadLength;
        tadShape = shape::shapeOf(tadOnlyShapeInfo);
        tadStride = shape::stride(tadOnlyShapeInfo);
    }
    __syncthreads();

@@ -138,7 +139,6 @@ __device__ void ReduceSameFunction<X>::transformCudaXD( void *vx, Nd4jLong *xSha

        sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock);

        for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
            auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
            sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams);
        }

@@ -146,7 +146,6 @@ __device__ void ReduceSameFunction<X>::transformCudaXD( void *vx, Nd4jLong *xSha

        // aggregate. do NOT reduce for elements > tadLength
        aggregatePartials<OpType>(sPartials, threadIdx.x, nd4j::math::nd4j_min<int>(blockDim.x, tadLength), extraParams);

        __syncthreads();

        if (threadIdx.x == 0)

@@ -172,7 +171,6 @@ __device__ void ReduceSameFunction<X>::execScalarCuda(void *vx, Nd4jLong *xShape

                                                      void *vz, Nd4jLong *zShapeInfo,
                                                      void *vreductionBuffer,
                                                      Nd4jLong *tadOnlyShapeInfo) {

    auto x = reinterpret_cast<X*>(vx);
    auto z = reinterpret_cast<X*>(vz);
    auto extraParams = reinterpret_cast<X*>(vextraParams);
@ -253,31 +251,63 @@ __device__ void ReduceSameFunction<X>::execScalarCuda(void *vx, Nd4jLong *xShape
|
|||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X>
|
template <typename X>
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
__host__ void ReduceSameFunction<X>::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShape, void *extraParams, void *z, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
__host__ void ReduceSameFunction<X>::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
|
|
||||||
simpleReduce<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
|
if(shape::isEmpty(hXShapeInfo)) {
|
||||||
|
|
||||||
|
if(shape::isEmpty(hZShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
|
const auto startingVal = static_cast<X>(OpType::startingValue(reinterpret_cast<X*>(x)));
|
||||||
|
|
||||||
|
auto res = cudaMemcpyAsync(nd4j::LaunchContext::defaultContext()->getScalarPointer(), &startingVal, sizeof(X), cudaMemcpyHostToDevice, *stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw nd4j::cuda_exception::build("ReduceSameFunction<X,Z>::intermediateXD: failed to copy temporary scalar", res);
|
||||||
|
|
||||||
|
auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer();
|
||||||
|
|
||||||
|
// scalar assign
|
||||||
|
functions::scalar::ScalarTransform<X, X, X>::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
simpleReduce<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename X>
|
template <typename X>
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
__host__ void ReduceSameFunction<X>::intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
|
__host__ void ReduceSameFunction<X>::intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo)) {
|
||||||
|
|
||||||
|
if (shape::isEmpty(hZShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
|
const auto startingVal = static_cast<X>(OpType::startingValue(reinterpret_cast<X*>(x)));
|
||||||
|
|
||||||
|
auto res = cudaMemcpyAsync(z, &startingVal, sizeof(X), cudaMemcpyHostToDevice, *stream);
|
||||||
|
if (res != 0)
|
||||||
|
throw nd4j::cuda_exception::build("ReduceSameFunction<X>::intermediateScalar: failed to copy resulting scalar", res);
|
||||||
|
}
|
||||||
|
else {
|
||||||
simpleScalar<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo);
|
     simpleScalar<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo);
 }
 }

 ////////////////////////////////////////////////////////////////////////
 template <typename X>
-_CUDA_H void ReduceSameFunction<X>::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
+_CUDA_H void ReduceSameFunction<X>::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {

-    DISPATCH_BY_OPNUM_T(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), REDUCE_SAME_OPS);
+    DISPATCH_BY_OPNUM_T(intermediateScalar, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo), REDUCE_SAME_OPS);
     nd4j::DebugHelper::checkErrorCode(stream, "execReduceScalarSame(...) failed");
 }

 ////////////////////////////////////////////////////////////////////////
 template <typename X>
-_CUDA_H void ReduceSameFunction<X>::execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *x, Nd4jLong *xShape, void *extraParams, void *z, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
+_CUDA_H void ReduceSameFunction<X>::execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *x, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

-    DISPATCH_BY_OPNUM_T(intermediateXD, PARAMS(launchDims, stream, x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), REDUCE_SAME_OPS);
+    DISPATCH_BY_OPNUM_T(intermediateXD, PARAMS(launchDims, stream, x, xShapeInfo, hXShapeInfo, extraParams, z, zShapeInfo, hZShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets), REDUCE_SAME_OPS);
     DEBUG_KERNEL(stream, opNum);
 }
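The change above threads host-side copies (`hXShapeInfo`, `hZShapeInfo`) of the device shape descriptors through the reduce launchers. A minimal sketch of why that dual-pointer pattern exists (the struct and function names here are hypothetical, not from the library): host code cannot dereference a device pointer, so any launch-time decision that depends on the shape needs a host mirror.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the pattern above: the same shape descriptor is kept
// twice -- hShapeInfo lives in host memory (readable immediately), dShapeInfo
// is the device copy (only kernels may dereference it).
struct ShapeInfoPair {
    std::vector<int64_t> hShapeInfo; // host mirror: {rank, dim0, dim1, ...}
    const int64_t*       dShapeInfo; // device pointer; opaque on the host
};

// Computes the element count from the HOST copy; touching dShapeInfo here
// would require a cudaMemcpy (or crash on a real device pointer).
int64_t hostLength(const ShapeInfoPair& s) {
    int64_t len = 1;
    for (int64_t i = 1; i <= s.hShapeInfo[0]; ++i)
        len *= s.hShapeInfo[i];
    return len;
}
```

The host mirror is what lets `execReduceScalar`-style launchers pick grid/block dimensions without a device-to-host copy on every call.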
@@ -69,14 +69,14 @@ namespace functions {
         static __device__ void transformCudaXD( void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);

         template<typename OpType>
-        static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);
+        static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);

         template<typename OpType>
-        static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);

-        static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);
+        static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);

-        static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *vx, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
 #endif

 /**
@@ -71,14 +71,14 @@ namespace functions {
         static __device__ void transformCudaXD( void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);

         template<typename OpType>
-        static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);
+        static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);

         template<typename OpType>
-        static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);

-        static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);
+        static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);

-        static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
 #endif

 /**
@@ -69,14 +69,14 @@ namespace functions {
         static __device__ void transformCudaXD( void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);

         template<typename OpType>
-        static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);
+        static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);

         template<typename OpType>
-        static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);

-        static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);
+        static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);

-        static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *vx, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);

 #endif

@@ -72,14 +72,14 @@ namespace functions {
         static __device__ void transformCudaXD( void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets);

         template<typename OpType>
-        static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);
+        static __host__ void intermediateScalar(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);

         template<typename OpType>
-        static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static __host__ void intermediateXD(dim3 launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);

-        static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);
+        static __host__ void execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo);

-        static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *vx, Nd4jLong *xShapeInfo, void *extraParams, void *vz, Nd4jLong *zShape, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static __host__ void execReduceXD(dim3 launchDims, cudaStream_t *stream, int opNum, int rank, void *vx, Nd4jLong *xShapeInfo, Nd4jLong* hXShapeInfo, void *extraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong* hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
 #endif

 /**
@@ -76,9 +76,9 @@ namespace functions {

         static void transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

-        static void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams);
+        static void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams, bool allowParallelism);

-        static void transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong len);
+        static void transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong len, bool allowParallelism);

@@ -101,7 +101,7 @@ namespace functions {
         */

         template<typename OpType>
-        static void transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams);
+        static void transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams, bool allowParallelism);

         /**
@@ -117,7 +117,7 @@ namespace functions {
         */

         template<typename OpType>
-        static void transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong len);
+        static void transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong len, bool allowParallelism);
     };
     }
 }
@@ -71,10 +71,10 @@ class TransformAny {

 #endif

-        static void exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static void exec(int opNum, void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism);

         template<typename OpType>
-        static ND4J_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
+        static ND4J_EXPORT void exec(void *dx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool allowParallelism);
     };

 }
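Every scalar/transform `exec` above grows a trailing `bool allowParallelism`. A minimal sketch of the plumbing this enables (function name hypothetical): a caller that is already inside a parallel region passes `false` so the transform runs serially instead of oversubscribing threads, while top-level callers pass `true`.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical sketch of the allowParallelism flag: gate the parallel path
// behind the flag; both paths compute the same result.
double sumTransform(const std::vector<double>& x, bool allowParallelism) {
    double acc = 0.0;
    if (allowParallelism) {
        // Parallel path; the pragma is ignored when OpenMP is disabled.
        #pragma omp parallel for reduction(+:acc)
        for (long i = 0; i < (long) x.size(); ++i)
            acc += x[i];
    } else {
        // Serial path for callers already inside a parallel region.
        acc = std::accumulate(x.begin(), x.end(), 0.0);
    }
    return acc;
}
```

The flag travels with the call rather than a global, so nested invocations can make the decision per call site.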
@@ -62,6 +62,8 @@ namespace nd4j {
         this->_initialSize = initialSize;
         this->_currentSize = initialSize;
+        this->_currentSizeSecondary = 0;
+        this->_spillsSizeSecondary = 0;
         this->_offset = 0;
         this->_offsetSecondary = 0;
         this->_cycleAllocations = 0;
@@ -84,6 +84,7 @@ namespace nd4j {
         this->_offsetSecondary = 0;
         this->_cycleAllocations = 0;
         this->_spillsSize = 0;
+        this->_spillsSizeSecondary = 0;
     }

     void Workspace::init(Nd4jLong primaryBytes, Nd4jLong secondaryBytes) {
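The Workspace constructor now zeroes a second set of counters (`_currentSizeSecondary`, `_spillsSizeSecondary`) alongside the primary ones. The field names are taken from the diff; the allocation behavior below is an assumption sketched for illustration: a workspace tracks a primary and a secondary region, each with its own offset and a spill counter for requests that do not fit.

```cpp
#include <cassert>
#include <cstddef>

// Sketch (behavior assumed, field names from the diff above): bump-pointer
// bookkeeping kept separately for a primary and a secondary memory region.
struct WorkspaceSketch {
    size_t currentSize = 0, offset = 0, spillsSize = 0;
    size_t currentSizeSecondary = 0, offsetSecondary = 0, spillsSizeSecondary = 0;

    // Returns the start offset of the reservation; records a spill on overflow.
    size_t allocatePrimary(size_t bytes) {
        if (offset + bytes > currentSize) { spillsSize += bytes; return 0; }
        size_t at = offset; offset += bytes; return at;
    }
    size_t allocateSecondary(size_t bytes) {
        if (offsetSecondary + bytes > currentSizeSecondary) { spillsSizeSecondary += bytes; return 0; }
        size_t at = offsetSecondary; offsetSecondary += bytes; return at;
    }
};
```

Keeping the counters separate means a device-memory spill does not pollute the host-side statistics, which matters when the two regions are sized independently (as `init(primaryBytes, secondaryBytes)` suggests).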
@@ -1459,12 +1459,12 @@
 #ifdef _RELEASE

 #define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT))); }
-#define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (WORKSPACE == nullptr) { auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; };
+#define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (VARIABLE != nullptr) {if (WORKSPACE == nullptr) { auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; }; };

 #else

 #define ALLOCATE_SPECIAL(VARIABLE, WORKSPACE, LENGTH, TT) if (WORKSPACE == nullptr) {auto erc_##VARIABLE = cudaMalloc(reinterpret_cast<void**>(&VARIABLE), LENGTH * sizeof(TT)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] allocation failed", erc_##VARIABLE);} else { nd4j::memory::MemoryTracker::getInstance()->countIn(nd4j::memory::MemoryType::DEVICE, VARIABLE, LENGTH * sizeof(TT)); }; } else {VARIABLE = reinterpret_cast<TT *>(WORKSPACE->allocateBytes(nd4j::memory::MemoryType::DEVICE, LENGTH * sizeof(TT))); }
-#define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (WORKSPACE == nullptr) { nd4j::memory::MemoryTracker::getInstance()->countOut(VARIABLE); auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; };
+#define RELEASE_SPECIAL(VARIABLE, WORKSPACE) if (VARIABLE != nullptr) {if (WORKSPACE == nullptr) { nd4j::memory::MemoryTracker::getInstance()->countOut(VARIABLE); auto erc_##VARIABLE = cudaFree(reinterpret_cast<void *>(VARIABLE)); if (erc_##VARIABLE != 0) {throw cuda_exception::build("[DEVICE] deallocation failed", erc_##VARIABLE);}; }; };

 #endif
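Both `RELEASE_SPECIAL` variants above gain an outer `VARIABLE != nullptr` guard, so releasing a pointer that was never allocated becomes a no-op instead of handing an unset value to the driver. A host-only sketch of the same idiom (`free()` stands in for `cudaFree`; the helper name is hypothetical):

```cpp
#include <cassert>
#include <cstdlib>

// Null-guarded release, mirroring the macro change above: skip the free when
// nothing was allocated, and null the pointer after freeing so a second
// release is also a no-op.
template <typename T>
bool releaseSpecial(T*& p) {
    if (p == nullptr)
        return false;      // nothing was allocated; skip the free
    std::free(p);          // stands in for cudaFree in this sketch
    p = nullptr;
    return true;
}
```

Nulling the pointer after release is what makes the guard compose: double-release paths, which are easy to hit in error-handling code, stay harmless.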
@@ -40,6 +40,7 @@
 #define PRAGMA_OMP_PARALLEL_REDUCTION(args)
 #define PRAGMA_OMP_PARALLEL_ARGS(args)
 #define PRAGMA_OMP_PARALLEL_THREADS(args)
+#define PRAGMA_OMP_PARALLEL_THREADS_IF(threads, condition)
 #define PRAGMA_OMP_PARALLEL_FOR
 #define PRAGMA_OMP_PARALLEL_FOR_ARGS(args)
 #define PRAGMA_OMP_PARALLEL_FOR_IF(args)
@@ -77,6 +78,7 @@
 #define PRAGMA_OMP_PARALLEL_REDUCTION(args) _Pragma(OMP_STRINGIFY(omp parallel reduction(args) default(shared)))
 #define PRAGMA_OMP_PARALLEL_ARGS(args) _Pragma(OMP_STRINGIFY(omp parallel args default(shared)))
 #define PRAGMA_OMP_PARALLEL_THREADS(args) _Pragma(OMP_STRINGIFY(omp parallel num_threads(args) if(args > 1) default(shared)))
+#define PRAGMA_OMP_PARALLEL_THREADS_IF(threads, condition) _Pragma(OMP_STRINGIFY(omp parallel num_threads(threads) if(condition) default(shared)))
 #define PRAGMA_OMP_PARALLEL_FOR _Pragma(OMP_STRINGIFY(omp parallel for default(shared)))
 #define PRAGMA_OMP_PARALLEL_FOR_REDUCTION(args) _Pragma(OMP_STRINGIFY(omp parallel for reduction(args) default(shared)))
 #define PRAGMA_OMP_PARALLEL_FOR_ARGS(args) _Pragma(OMP_STRINGIFY(omp parallel for args default(shared)))

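The new `PRAGMA_OMP_PARALLEL_THREADS_IF(threads, condition)` macro decouples the thread count from the decision to parallelize: the OpenMP `if` clause keeps the region serial when the condition is false, so small workloads skip thread startup entirely. A self-contained sketch of the effect (function name hypothetical; the pragma is simply ignored when OpenMP is disabled):

```cpp
#include <cassert>
#include <vector>

// Conditional parallelism via the OpenMP `if` clause, as the new macro
// expands to: serial when `worthIt` is false, up to `threads` workers when
// true. The result is identical either way.
double sumIf(const std::vector<double>& x, int threads, bool worthIt) {
    double acc = 0.0;
    #pragma omp parallel for num_threads(threads) if(worthIt) reduction(+:acc) default(shared)
    for (long i = 0; i < (long) x.size(); ++i)
        acc += x[i];
    return acc;
}
```

This matches how `allowParallelism` can be wired through: pass the flag straight into the `if` clause instead of duplicating serial and parallel code paths.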
@@ -50,7 +50,9 @@ namespace nd4j {
         */

             // FIXME: set proper extras here
-            y->applyPairwiseTransform(pairwise::Axpy, x, z, nullptr);
+            ExtraArguments arguments({a});
+
+            y->applyPairwiseTransform(pairwise::Axpy, x, z, &arguments);

             return ND4J_STATUS_OK;
         }
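The fix above stops dropping the scalar: `a` is now delivered to the Axpy pairwise transform through `ExtraArguments` instead of a `nullptr`. For reference, a plain-vector sketch of classic BLAS-style axpy semantics (which operand the library scales in its pairwise variant is not shown in the diff, so this follows the BLAS convention z = a*x + y):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// BLAS-convention axpy on plain vectors: z[i] = a * x[i] + y[i].
// Without the scalar `a`, the operation degenerates, which is what the
// dropped-argument bug above produced.
std::vector<double> axpy(double a, const std::vector<double>& x, const std::vector<double>& y) {
    std::vector<double> z(x.size());
    for (size_t i = 0; i < x.size(); ++i)
        z[i] = a * x[i] + y[i];
    return z;
}
```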
@@ -102,7 +102,7 @@ namespace nd4j {
             COPY_SHAPE(x, shapeE);
             COPY_SHAPE(y, shapeG);

-            auto shapeList = SHAPELIST(shapeE, shapeG);
+            auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));

             return shapeList;
         }
@@ -67,9 +67,14 @@ namespace nd4j {
             auto gradX = OUTPUT_VARIABLE(0);
             auto gradY = OUTPUT_VARIABLE(1);

-            gradY->assign(0.0f);
-            gradX->assign(0.0f);
+            gradY->assign(x);
+            std::unique_ptr<NDArray> ySq(y->dup());
+            ySq->applyTransform(transform::Square, nullptr);
+            gradY->applyPairwiseTransform(pairwise::FloorDiv, ySq.get(), gradY, nullptr);
+            gradY->applyPairwiseTransform(pairwise::Multiply, epsNext, gradY, nullptr);
+            gradY->applyTransform(transform::Neg, nullptr);
+            gradX->assign(epsNext);
+            //gradX->applyPairwiseTransform(pairwise::FloorDiv, y, gradX, nullptr);
             return Status::OK();
         }

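The backward pass above replaces the zeroed gradients with a real computation. The underlying identity: with floormod(x, y) = x - floor(x / y) * y, and treating floor(x / y) as locally constant (it is piecewise constant, so its derivative is zero away from the discontinuities), the gradients are dL/dx = eps and dL/dy = -floor(x / y) * eps. A scalar sketch of that math (the diff's NDArray formulation arrives at the quotient differently, via a squared divisor and FloorDiv; the sketch states the identity, not the library's exact call sequence):

```cpp
#include <cassert>
#include <cmath>

// Scalar floormod backward: gradX = eps, gradY = -floor(x / y) * eps,
// valid away from the discontinuities of floor(x / y).
void floormodBp(double x, double y, double eps, double& gradX, double& gradY) {
    gradX = eps;
    gradY = -std::floor(x / y) * eps;
}
```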
@@ -87,9 +92,7 @@ namespace nd4j {
             COPY_SHAPE(x, shapeE);
             COPY_SHAPE(y, shapeG);

-            auto shapeList = SHAPELIST(shapeE, shapeG);
-
-            return shapeList;
+            return SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
         }
     }
 }
@@ -22,6 +22,7 @@
 #if NOT_EXCLUDED(OP_apply_sgd)

 #include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/gradient.h>

 namespace nd4j {
     namespace ops {
@@ -44,14 +45,8 @@ namespace nd4j {
             auto Z = OUTPUT_VARIABLE(0);

-            // FIXME: lambda
-            /*
-            auto lambda = LAMBDA_TT(_x, _y, lr) {
-                return _x - (_y * lr);
-            };
-
-            parameters->applyPairwiseLambda(gradients, lambda, Z);
-            */
+            helpers::applyGradientDescent(block.launchContext(), parameters, gradients, lr, Z);

             return Status::OK();
         }
         DECLARE_SYN(ApplyGradientDescent, apply_sgd);
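The commented-out lambda above computed `_x - (_y * lr)`; the `helpers::applyGradientDescent` call that replaces it gives the same update a CUDA-capable home. A minimal host-side equivalent of that step, inferred from the removed lambda (the vector signature here is illustrative, not the helper's actual interface):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain SGD step, elementwise: w := w - lr * g, matching the removed
// lambda `_x - (_y * lr)`.
std::vector<double> applyGradientDescent(const std::vector<double>& params,
                                         const std::vector<double>& grads,
                                         double lr) {
    std::vector<double> out(params.size());
    for (size_t i = 0; i < params.size(); ++i)
        out[i] = params[i] - grads[i] * lr;
    return out;
}
```

Moving the body into a helper rather than a device lambda is what lets the CPU and CUDA backends share one declaration, which is the pattern this PR applies across several ops.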
@ -28,8 +28,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by raver119 on 19.01.18.
|
// @author raver119@gmail.com, created on 19.01.18.
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
@ -40,261 +42,87 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
const int kMaxSpaceToBatchBlockDims = 4;
|
|
||||||
|
|
||||||
DECLARE_TYPES(batch_to_space) {
|
|
||||||
getOpDescriptor()
|
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
|
||||||
->setSameMode(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(batch_to_space, 1, 1, false, 0, -2) {
|
CUSTOM_OP_IMPL(batch_to_space, 2, 1, false, 0, 1) {
|
||||||
|
|
||||||
|
// [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is rearranged/permuted to [bS, oH, oW, iC]
|
||||||
|
// oH = H - cropTop - cropBottom
|
||||||
|
// oW = W - cropLeft - cropRight
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto crop = INPUT_VARIABLE(1);
|
||||||
bool order_changed = false;
|
|
||||||
if (input->ordering() != 'c' || input->ews() != 1 || input->isView()) {
|
|
||||||
order_changed = true;
|
|
||||||
input = input->dup('c');
|
|
||||||
}
|
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
const int input_dims = input->rankOf();
|
const uint blockSize = INT_ARG(0);
|
||||||
int block_dims = 0;
|
REQUIRE_TRUE(blockSize >= 2, 0, "BatchToSpace: integer parameter block_size must be >= 2, but got %i instead", blockSize);
|
||||||
std::vector<Nd4jLong> block_shape; // = blocks->template asVectorT<int>();
|
|
||||||
std::vector<Nd4jLong> crops_shape; // = crops->template asVectorT<int>();
|
|
||||||
|
|
||||||
if (block.width() >= 3) {
|
const int rank = input->rankOf();
|
||||||
auto blocks = INPUT_VARIABLE(1);
|
const int dim0 = input->sizeAt(0);
|
||||||
auto crops = INPUT_VARIABLE(2);
|
REQUIRE_TRUE(rank == 4, 0, "BatchToSpace: rank of input array must be equal 4, but got %i instead", rank);
|
||||||
|
REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, "BatchToSpace: first dimension of input array must be divisible by blockSize * blockSize (that is by %i), but got first dimension equal to %i", blockSize * blockSize, dim0);
|
||||||
|
|
||||||
block_dims = (int) blocks->sizeAt(0);
|
const std::string expectedCropShape = "[2, 2]";
|
||||||
|
const std::string actualCropShape = ShapeUtils::shapeAsString(crop);
|
||||||
|
REQUIRE_TRUE(actualCropShape == expectedCropShape, 0, "BatchToSpace: operation expects crop shape to be {2, 2}, but got %s instead", actualCropShape.c_str());
|
||||||
|
|
||||||
REQUIRE_TRUE(blocks->isVector() || blocks->lengthOf() == 1, 0, "BatchToSpace: blocks supposed to be vector or scalar, but got %iD instead", blocks->rankOf());
|
const uint cropBottom = crop->e<uint>(0,0);
|
||||||
REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf(), 0, "BatchToSpace: blocks length + 1 should match input rank at least");
|
const uint cropTop = crop->e<uint>(0,1);
|
||||||
REQUIRE_TRUE(crops->rankOf() == 2, 0, "BatchToSpace: padding should have rank of 2, but got %i instead", crops->rankOf());
|
const uint cropLeft = crop->e<uint>(1,0);
|
||||||
REQUIRE_TRUE(crops->columns() == 2 && blocks->lengthOf() == crops->rows(), 0, "BatchToSpace: padding should have M rows and 2 columns");
|
const uint cropRight = crop->e<uint>(1,1);
|
||||||
|
|
||||||
block_shape = blocks->template asVectorT<Nd4jLong>();
|
const int oH = input->sizeAt(1) * blockSize - cropBottom - cropTop; // top and bottom
|
||||||
crops_shape = crops->template asVectorT<Nd4jLong>();
|
const int oW = input->sizeAt(2) * blockSize - cropLeft - cropRight; // left and right
|
||||||
|
REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !");
|
||||||
|
REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !");
|
||||||
|
|
||||||
} else if (block.numI() > 0) {
|
helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize);
|
||||||
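The `REQUIRE_TRUE` checks above pin down the new batch_to_space contract: a rank-4 input `[bS*blockSize*blockSize, H, W, iC]` becomes `[bS, oH, oW, iC]` with `oH = H*blockSize - cropBottom - cropTop` and `oW = W*blockSize - cropLeft - cropRight`. A sketch of just that shape arithmetic (function name hypothetical; the validation mirrors the REQUIRE_TRUE calls):

```cpp
#include <cassert>

// Shape arithmetic for the rewritten batch_to_space: returns false when the
// inputs violate the same conditions the op's REQUIRE_TRUE checks enforce.
bool batchToSpaceShape(int dim0, int h, int w, int blockSize,
                       int cropBottom, int cropTop, int cropLeft, int cropRight,
                       int& bS, int& oH, int& oW) {
    if (blockSize < 2 || dim0 % (blockSize * blockSize) != 0)
        return false;                                  // block_size and batch checks
    bS = dim0 / (blockSize * blockSize);               // batch shrinks by blockSize^2
    oH = h * blockSize - cropBottom - cropTop;         // spatial dims grow, minus crops
    oW = w * blockSize - cropLeft - cropRight;
    return oH >= 0 && oW >= 0;                         // crops must not exceed the expanded extent
}
```

Note the symmetry with space_to_batch (also touched in this PR): the batch dimension shrinks by exactly the factor the two spatial dimensions gain.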
-            int totalArgs = block.numI();
-
-            int M = totalArgs / 3;
-            REQUIRE_TRUE(totalArgs % 3 == 0, 0, "BatchToSpace: number of IntArguments should be dividable by 3 without reminder");
-
-            block_dims = M;
-            block_shape.resize(block_dims);
-            crops_shape.resize(M*2);
-
-            REQUIRE_TRUE(input->rankOf() >= 1 + M + 1, 0, "BatchToSpace: blocks length + 2 should match input rank at least");
-
-            int e = 0;
-            for (; e < block_dims; e++)
-                block_shape[e] = INT_ARG(e);
-
-            for (; e < block.numI(); e++)
-                crops_shape[e - M] = INT_ARG(e);
-        } else {
-            REQUIRE_TRUE(false, 0, "BatchToSpace: there should be some params :(");
-        }
-
-        int removed_prefix_block_dims = 0;
-        for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
-            const int dim = removed_prefix_block_dims;
-            if (crops_shape[2 * dim] != 0 || crops_shape[2 * dim + 1] != 0 || block_shape[dim] != 1)
-                break;
-        }
-
-        int removed_suffix_block_dims = 0;
-        for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims; ++removed_suffix_block_dims) {
-            const int dim = block_dims - 1 - removed_suffix_block_dims;
-            if (crops_shape[2 * dim] != 0 || crops_shape[2 * dim + 1] != 0 || block_shape[dim] != 1)
-                break;
-        }
-
-        int block_shape_product = 1;
-        for (int block_dim = 0; block_dim < block_dims; ++block_dim)
-            block_shape_product *= block_shape[block_dim];
-
-        REQUIRE_TRUE(block_shape_product > 0, 0, "BatchToSpace: block should contain values >= 1 ONLY");
-
-        const Nd4jLong orig_input_batch_size = input->sizeAt(0);
-        const int internal_block_dims = block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
-
-        REQUIRE_TRUE(internal_block_dims <= kMaxSpaceToBatchBlockDims, 0, "BatchToSpace: Maximum number of non-combined block dimensions should be less or equal then %i but got %i instead", kMaxSpaceToBatchBlockDims, internal_block_dims);
-
-        if (internal_block_dims == 0) {
-            output->assign(input);
-            return Status::OK();
-        }
-
-        std::vector<Nd4jLong> internal_input_shape;
-        std::vector<Nd4jLong> internal_output_shape;
-        std::vector<Nd4jLong> external_output_shape;
-
-        external_output_shape.emplace_back(orig_input_batch_size / block_shape_product);
-
-        int input_batch_size = orig_input_batch_size;
-        for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
-            const int size = input->sizeAt(block_dim + 1);
-            input_batch_size *= size;
-            external_output_shape.emplace_back(size);
-        }
-        internal_input_shape.emplace_back(input_batch_size);
-        internal_output_shape.emplace_back(input_batch_size / block_shape_product);
-
-        for (int block_dim = removed_prefix_block_dims;
-             block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
-            const int crop_start = crops_shape[2 * block_dim];
-            const int crop_end = crops_shape[2 * block_dim + 1];
|
|
||||||
|
|
||||||
const int input_size = input->sizeAt(block_dim + 1);
|
|
||||||
const int block_shape_value = block_shape[block_dim];
|
|
||||||
const int cropped_size = input_size * block_shape_value - crop_start - crop_end;
|
|
||||||
|
|
||||||
REQUIRE_TRUE(cropped_size >= 0, 0, "BatchToSpace: cropped_size should have non-negative value");
|
|
||||||
|
|
||||||
internal_input_shape.emplace_back(input_size);
|
|
||||||
internal_output_shape.emplace_back(cropped_size);
|
|
||||||
external_output_shape.emplace_back(cropped_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
int depth = 1;
|
|
||||||
for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims; ++dim) {
|
|
||||||
const int size = input->sizeAt(dim);
|
|
||||||
external_output_shape.emplace_back(size);
|
|
||||||
depth *= size;
|
|
||||||
}
|
|
||||||
|
|
||||||
internal_input_shape.emplace_back(depth);
|
|
||||||
internal_output_shape.emplace_back(depth);
|
|
||||||
|
|
||||||
auto internal_crops = &crops_shape.data()[2 * removed_prefix_block_dims];
|
|
||||||
auto internal_block_shape = &block_shape.data()[removed_prefix_block_dims];
|
|
||||||
|
|
||||||
helpers::_batchToSpace(block.launchContext(), internal_block_dims, output, input, internal_output_shape, internal_input_shape, internal_block_shape, internal_crops);
|
|
||||||
|
|
||||||
if (order_changed)
|
|
||||||
delete input;
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
+////////////////////////////////////////////////////////////////////////////////
+DECLARE_TYPES(batch_to_space) {
+    getOpDescriptor()->setAllowedInputTypes(0, nd4j::DataType::ANY)
+                     ->setAllowedInputTypes(1, {ALL_INTS})
+                     ->setSameMode(true);
+}
+
+////////////////////////////////////////////////////////////////////////////////
 DECLARE_SHAPE_FN(batch_to_space) {
-    auto in = inputShape->at(0);
-
-    const int input_dims = shape::rank(in);
-    int block_dims = 0;
-
-    std::vector<int> block_shape;
-    std::vector<int> crops_shape;
-
-    if (block.width() >= 3) {
-        auto blocks = INPUT_VARIABLE(1);
-        auto crops = INPUT_VARIABLE(2);
-
-        block_dims = (int) blocks->sizeAt(0);
-
-        block_shape = blocks->template asVectorT<int>();
-        crops_shape = crops->template asVectorT<int>();
-
-        //shape::printShapeInfoLinear("STB input shape: ",in);
-        //blocks->printBuffer("STB blocks");
-        //crops->printBuffer("STB crops");
-    } else if (block.numI() > 0) {
-        int totalArgs = block.numI();
-
-        int M = totalArgs / 3;
-
-        block_dims = M;
-        block_shape.resize(block_dims);
-        crops_shape.resize(M*2);
-
-        int e = 0;
-        for (; e < block_dims; e++)
-            block_shape[e] = INT_ARG(e);
-
-        for (; e < block.numI(); e++)
-            crops_shape[e - M] = INT_ARG(e);
-    } else {
-        // throw something here
-    }
-
-    int removed_prefix_block_dims = 0;
-    for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
-        const int dim = removed_prefix_block_dims;
-        if (crops_shape[2 * dim] != 0 || crops_shape[2 * dim + 1] != 0 || block_shape[dim] != 1)
-            break;
-    }
-
-    int removed_suffix_block_dims = 0;
-    for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims; ++removed_suffix_block_dims) {
-        const int dim = block_dims - 1 - removed_suffix_block_dims;
-        if (crops_shape[2 * dim] != 0 || crops_shape[2 * dim + 1] != 0 || block_shape[dim] != 1)
-            break;
-    }
-
-    int block_shape_product = 1;
-    for (int block_dim = 0; block_dim < block_dims; ++block_dim)
-        block_shape_product *= block_shape[block_dim];
-
-    const int orig_input_batch_size = shape::sizeAt(in, 0);
-    const int internal_block_dims = block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
-
-    if (internal_block_dims == 0) {
-        // just return input shape here
-        Nd4jLong *newShape;
-        COPY_SHAPE(in, newShape);
-        return SHAPELIST(newShape);
-    }
-
-    // go full route otherwise
-    std::vector<Nd4jLong> internal_input_shape;
-    std::vector<Nd4jLong> internal_output_shape;
-    std::vector<Nd4jLong> external_output_shape;
-
-    external_output_shape.emplace_back(orig_input_batch_size / block_shape_product);
-
-    auto input_batch_size = orig_input_batch_size;
-    for (int block_dim = 0; block_dim < removed_prefix_block_dims; ++block_dim) {
-        const Nd4jLong size = shape::sizeAt(in, block_dim + 1);
-        input_batch_size *= size;
-        external_output_shape.emplace_back(size);
-    }
-    internal_input_shape.emplace_back(input_batch_size);
-    internal_output_shape.emplace_back(input_batch_size / block_shape_product);
-
-    for (int block_dim = removed_prefix_block_dims; block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
-        const int crop_start = crops_shape[2 * block_dim];
-        const int crop_end = crops_shape[2 * block_dim + 1];
-
-        const int input_size = shape::sizeAt(in, block_dim + 1);
-        const int block_shape_value = block_shape[block_dim];
-        const int cropped_size = input_size * block_shape_value - crop_start - crop_end;
-
-        internal_input_shape.emplace_back(input_size);
-        internal_output_shape.emplace_back(cropped_size);
-        external_output_shape.emplace_back(cropped_size);
-    }
-
-    int depth = 1;
-    for (int dim = block_dims - removed_suffix_block_dims + 1; dim < input_dims; ++dim) {
-        const int size = shape::sizeAt(in, dim);
-        external_output_shape.emplace_back(size);
-        depth *= size;
-    }
+    auto inputShapeInfo = inputShape->at(0);
+    auto cropShapeInfo  = inputShape->at(1);
+
+    const uint blockSize = INT_ARG(0);
+    REQUIRE_TRUE(blockSize >= 2, 0, "BatchToSpace: integer parameter block_size must be >= 2, but got %i instead", blockSize);
+
+    const int rank = inputShapeInfo[0];
+    const int dim0 = inputShapeInfo[1];
+    REQUIRE_TRUE(rank == 4, 0, "BatchToSpace: rank of input array must be equal 4, but got %i instead", rank);
+    REQUIRE_TRUE(dim0 % (blockSize * blockSize) == 0, 0, "BatchToSpace: first dimension of input array must be divisible by blockSize * blockSize (that is by %i), but got first dimension equal to %i", blockSize * blockSize, dim0);
+
+    const std::string expectedCropShape = "[2, 2]";
+    const std::string actualCropShape = ShapeUtils::shapeAsString(cropShapeInfo);
+    REQUIRE_TRUE(actualCropShape == expectedCropShape, 0, "BatchToSpace: operation expects crop shape to be {2, 2}, but got %s instead", actualCropShape.c_str());
+
+    const uint cropBottom = INPUT_VARIABLE(1)->e<Nd4jLong>(0,0);
+    const uint cropTop    = INPUT_VARIABLE(1)->e<Nd4jLong>(0,1);
+    const uint cropLeft   = INPUT_VARIABLE(1)->e<Nd4jLong>(1,0);
+    const uint cropRight  = INPUT_VARIABLE(1)->e<Nd4jLong>(1,1);
+
+    const int oH = inputShapeInfo[2] * blockSize - cropTop - cropBottom;    // top and bottom
+    const int oW = inputShapeInfo[3] * blockSize - cropLeft - cropRight;    // left and right
+
+    REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !");
+    REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !");
 
     // we always give out C order here
-    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), 'c', external_output_shape));
+    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', {dim0 / (blockSize * blockSize), oH, oW, inputShapeInfo[4]}));
 }
 
 }
 }
@@ -73,8 +73,9 @@ namespace nd4j {
     outputShape[2] = height;
     outputShape[3] = in[4];
 
-    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(in), shape::order(in), outputShape, 4)));
+    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(DataType::FLOAT32, shape::order(in), outputShape, 4)));
 }
 
 DECLARE_TYPES(crop_and_resize) {
     getOpDescriptor()
         ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
@@ -36,7 +36,7 @@ CUSTOM_OP_IMPL(parallel_stack, -1, 1, false, 0, 0) {
     for (int i = 0; i < (int) block.width() - 1; ++i)
         REQUIRE_TRUE(shape::equalsSoft((INPUT_VARIABLE(i))->getShapeInfo(), (INPUT_VARIABLE(i+1))->getShapeInfo()), 0, "PARALLEL_STACK op: the shapes of all input arrays must be the same !");
 
-    std::vector<NDArray*> inArrs(block.width());
+    std::vector<const NDArray*> inArrs(block.width());
     for(int i = 0; i < block.width(); ++i)
         inArrs[i] = INPUT_VARIABLE(i);
 
@@ -15,7 +15,7 @@
 ******************************************************************************/
 
 //
-// @author Created by raver119 on 24.11.17.
+// @author raver119@gmail.com, created on 24.11.17.
 // @author Yurii Shyrma (iuriish@yahoo.com)
 //
 
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 //
-// Created by raver119 on 19.01.18.
+// @author raver119@gmail.com, created on 19.01.18.
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 
 #include <op_boilerplate.h>
@@ -24,272 +25,73 @@ limitations under the License.
 
 namespace nd4j {
 namespace ops {
-    const int kMaxSpaceToBatchBlockDims = 4;
-
-    DECLARE_TYPES(space_to_batch) {
-        getOpDescriptor()
-            ->setAllowedInputTypes(nd4j::DataType::ANY)
-            ->setSameMode(true);
-    }
-
-CUSTOM_OP_IMPL(space_to_batch, 1, 1, false, 0, -2) {
+CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) {
+
+    // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]
 
     auto input = INPUT_VARIABLE(0);
-
-    std::vector<Nd4jLong> block_shape;
-    std::vector<Nd4jLong> padding_shape;
-
-    bool order_changed = false;
-    if (input->ordering() != 'c') {
-        order_changed = true;
-        input = input->dup('c');
-    }
+    auto padding = INPUT_VARIABLE(1);
 
     auto output = OUTPUT_VARIABLE(0);
 
-    const int xRank = input->rankOf();
-    int block_dims = 0;
-
-    if (block.width() >= 3) {
-        auto blocks = INPUT_VARIABLE(1);
-        auto padding = INPUT_VARIABLE(2);
-
-        block_dims = (int) blocks->lengthOf();
-
-        REQUIRE_TRUE(blocks->isVector() || blocks->lengthOf() == 1, 0, "SpaceToBatch: blocks supposed to be vector or scalar, but got %iD instead", blocks->rankOf());
-        REQUIRE_TRUE(input->rankOf() >= 1 + blocks->lengthOf(), 0, "SpaceToBatch: blocks length + 1 should match input rank at least");
-        REQUIRE_TRUE(padding->rankOf() == 2, 0, "SpaceToBatch: padding should have rank of 2, but got %i instead", padding->rankOf());
-        REQUIRE_TRUE(padding->columns() == 2 && blocks->lengthOf() == padding->rows(), 0, "SpaceToBatch: padding should have M rows and 2 columns");
-
-        block_shape = blocks->template asVectorT<Nd4jLong>();
-        padding_shape = padding->template asVectorT<Nd4jLong>();
-
-    } else if (block.numI() > 0) {
-        int totalArgs = block.numI();
-
-        int M = totalArgs / 3;
-        REQUIRE_TRUE(totalArgs % 3 == 0, 0, "SpaceToBatch: number of IntArguments should be dividable by 3 without reminder");
-
-        block_dims = M;
-        block_shape.resize(block_dims);
-        padding_shape.resize(M*2);
-
-        REQUIRE_TRUE(input->rankOf() >= 1 + M, 0, "SpaceToBatch: blocks length + 1 should match input rank at least");
-
-        int e = 0;
-        for (; e < block_dims; e++)
-            block_shape[e] = INT_ARG(e);
-
-        for (; e < block.numI(); e++)
-            padding_shape[e - M] = INT_ARG(e);
-
-    } else {
-        REQUIRE_TRUE(false, 0, "SpaceToBatch: there should be some params :(");
-    }
-
-    // Determine the length of the prefix of block dims that can be combined
-    // into the batch dimension due to having no padding and block_shape=1.
-    int removed_prefix_block_dims = 0;
-    for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
-        const int dim = removed_prefix_block_dims;
-        if (padding_shape[2 * dim] != 0 || padding_shape[2 * dim + 1] != 0 || block_shape[dim] != 1)
-            break;
-    }
-
-    // Determine the length of the suffix of block dims that can be combined
-    // into the depth dimension due to having no padding and block_shape=1.
-    int removed_suffix_block_dims = 0;
-    for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims; ++removed_suffix_block_dims) {
-        const int dim = block_dims - 1 - removed_suffix_block_dims;
-        if (padding_shape[dim * 2] != 0 || padding_shape[dim * 2 + 1] != 0 || block_shape[dim] != 1)
-            break;
-    }
-
-    int block_shape_product = 1;
-    for (int block_dim = 0; block_dim < block_dims; ++block_dim)
-        block_shape_product *= block_shape[block_dim];
-    REQUIRE_TRUE(block_shape_product > 0, 0, "SpaceToBatch: block should contain values >= 1 ONLY");
-
-    const int internal_block_dims = block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
-    REQUIRE_TRUE(internal_block_dims <= kMaxSpaceToBatchBlockDims, 0, "SpaceToBatch: Maximum number of non-combined block dimensions should be less or equal then %i but got %i instead", kMaxSpaceToBatchBlockDims, internal_block_dims);
-
-    if (internal_block_dims == 0) {
-        // we return array if there's nothing to move here
-        output->assign(input);
-        return Status::OK();
-    }
-
-    std::vector<Nd4jLong> internal_input_shape;
-    std::vector<Nd4jLong> internal_output_shape;
-    std::vector<Nd4jLong> external_output_shape;
-
-    external_output_shape.emplace_back(input->sizeAt(0) * block_shape_product);
-    int input_batch_size = input->sizeAt(0);
-    for (int block_dim = 0; block_dim < removed_prefix_block_dims; block_dim++) {
-        const int size = input->sizeAt(block_dim + 1);
-        input_batch_size *= size;
-        external_output_shape.emplace_back(size);
-    }
-    internal_input_shape.emplace_back(input_batch_size);
-    internal_output_shape.emplace_back(input_batch_size * block_shape_product);
-
-    for (int block_dim = removed_prefix_block_dims; block_dim < block_dims - removed_suffix_block_dims; block_dim++) {
-        const int pad_start = padding_shape[2 * block_dim];
-        const int pad_end = padding_shape[2 * block_dim + 1];
-
-        const int input_size = input->sizeAt(block_dim + 1);
-        const int block_shape_value = block_shape[block_dim];
-        const int padded_size = input_size + pad_start + pad_end;
-        const int output_size = padded_size / block_shape_value;
-
-        // FIXME: validation required here
-
-        internal_input_shape.emplace_back(input_size);
-        internal_output_shape.emplace_back(output_size);
-        external_output_shape.emplace_back(output_size);
-    }
-
-    int depth = 1;
-    for (int dim = block_dims - removed_suffix_block_dims + 1; dim < xRank; dim++) {
-        const int size = input->sizeAt(dim);
-        external_output_shape.emplace_back(size);
-        depth *= size;
-    }
-
-    internal_input_shape.emplace_back(depth);
-    internal_output_shape.emplace_back(depth);
-
-    Nd4jLong* internal_paddings = &padding_shape.data()[2 * removed_prefix_block_dims];
-    Nd4jLong* internal_block_shape = &block_shape.data()[removed_prefix_block_dims];
-
-    helpers::_spaceToBatch(block.launchContext(), internal_block_dims, input, output, internal_input_shape, internal_output_shape, internal_block_shape, internal_paddings);
-
-    if (order_changed)
-        delete input;
+    const uint blockSize = INT_ARG(0);
+    REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize);
+
+    const int rank = input->rankOf();
+    REQUIRE_TRUE(rank == 4, 0, "SpaceToBatch: rank of input array must be equal 4, but got %i instead", rank);
+
+    const std::string expectedpaddingShape = "[2, 2]";
+    const std::string actualpaddingShape = ShapeUtils::shapeAsString(padding);
+    REQUIRE_TRUE(actualpaddingShape == expectedpaddingShape, 0, "SpaceToBatch: operation expects padding shape to be {2, 2}, but got %s instead", actualpaddingShape.c_str());
+
+    const uint padBottom = padding->e<uint>(0,0);
+    const uint padTop    = padding->e<uint>(0,1);
+    const uint padLeft   = padding->e<uint>(1,0);
+    const uint padRight  = padding->e<uint>(1,1);
+
+    REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !");
+
+    helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize);
 
     return Status::OK();
 }
 
+////////////////////////////////////////////////////////////////////////////////
+DECLARE_TYPES(space_to_batch) {
+    getOpDescriptor()->setAllowedInputTypes(0, nd4j::DataType::ANY)
+                     ->setAllowedInputTypes(1, {ALL_INTS})
+                     ->setSameMode(true);
+}
+
+////////////////////////////////////////////////////////////////////////////////
 DECLARE_SHAPE_FN(space_to_batch) {
-    auto in = inputShape->at(0);
-
-    const int xRank = shape::rank(in);
-    int block_dims = 0;
-
-    std::vector<int> block_shape;
-    std::vector<int> padding_shape;
-
-    if (block.width() >= 3) {
-        auto blocks = INPUT_VARIABLE(1);
-        auto padding = INPUT_VARIABLE(2);
-
-        block_dims = (int) blocks->lengthOf();
-
-        block_shape.resize(block_dims);
-        padding_shape.resize(padding->lengthOf());
-
-        for (int e = 0; e < block_dims; e++)
-            block_shape[e] = blocks->e<int>(e);
-
-        for (int e = 0; e < padding->lengthOf(); e++)
-            padding_shape[e] = padding->e<int>(e);
-
-    } else if (block.numI() > 0) {
-        int totalArgs = block.numI();
-
-        int M = totalArgs / 3;
-
-        block_dims = M;
-        block_shape.resize(block_dims);
-        padding_shape.resize(M*2);
-
-        int e = 0;
-        for (; e < block_dims; e++)
-            block_shape[e] = INT_ARG(e);
-
-        for (; e < block.numI(); e++)
-            padding_shape[e - M] = INT_ARG(e);
-
-    } else {
-        // throw something here
-    }
-
-    int removed_prefix_block_dims = 0;
-    for (; removed_prefix_block_dims < block_dims; ++removed_prefix_block_dims) {
-        const int dim = removed_prefix_block_dims;
-        if (padding_shape[2 * dim] != 0 || padding_shape[2 * dim + 1] != 0 || block_shape[dim] != 1)
-            break;
-    }
-
-    int removed_suffix_block_dims = 0;
-    for (; removed_suffix_block_dims < block_dims - removed_prefix_block_dims; ++removed_suffix_block_dims) {
-        const int dim = block_dims - 1 - removed_suffix_block_dims;
-        if (padding_shape[dim * 2] != 0 || padding_shape[dim * 2 + 1] != 0 || block_shape[dim] != 1)
-            break;
-    }
-
-    int block_shape_product = 1;
-    for (int block_dim = 0; block_dim < block_dims; ++block_dim)
-        block_shape_product *= block_shape[block_dim];
-
-    const int internal_block_dims = block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
-
-    if (internal_block_dims == 0) {
-        // just return input shape here
-        Nd4jLong *newShape;
-        COPY_SHAPE(in, newShape);
-        return SHAPELIST(CONSTANT(newShape));
-    }
-
-    // go full route otherwise
-    std::vector<Nd4jLong> internal_input_shape;
-    std::vector<Nd4jLong> internal_output_shape;
-    std::vector<Nd4jLong> external_output_shape;
-
-    external_output_shape.emplace_back(shape::sizeAt(in, 0) * block_shape_product);
-    Nd4jLong input_batch_size = shape::sizeAt(in, 0);
-    for (int block_dim = 0; block_dim < removed_prefix_block_dims; block_dim++) {
-        const int size = shape::sizeAt(in, block_dim + 1);
-        input_batch_size *= size;
-        external_output_shape.emplace_back(size);
-    }
-    internal_input_shape.emplace_back(input_batch_size);
-    internal_output_shape.emplace_back(input_batch_size * block_shape_product);
-
-    for (int block_dim = removed_prefix_block_dims; block_dim < block_dims - removed_suffix_block_dims; block_dim++) {
-        const Nd4jLong pad_start = padding_shape[2 * block_dim];
-        const Nd4jLong pad_end = padding_shape[2 * block_dim + 1];
-
-        const Nd4jLong input_size = shape::sizeAt(in, block_dim + 1);
-        const Nd4jLong block_shape_value = block_shape[block_dim];
-        const Nd4jLong padded_size = input_size + pad_start + pad_end;
-        const Nd4jLong output_size = padded_size / block_shape_value;
-
-        // FIXME: validation required here
-
-        internal_input_shape.emplace_back(input_size);
-        internal_output_shape.emplace_back(output_size);
-        external_output_shape.emplace_back(output_size);
-    }
-
-    int depth = 1;
-    for (int dim = block_dims - removed_suffix_block_dims + 1; dim < xRank; dim++) {
-        const Nd4jLong size = shape::sizeAt(in, dim);
-        external_output_shape.emplace_back(size);
-        depth *= size;
-    }
-
-    internal_input_shape.emplace_back(depth);
-    internal_output_shape.emplace_back(depth);
-
-    // we always give out C order here
-    auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(in), 'c', external_output_shape);
-    return SHAPELIST(newShape);
+    auto inputShapeInfo = inputShape->at(0);
+    auto paddingShapeInfo = inputShape->at(1);
+
+    const uint blockSize = INT_ARG(0);
+    REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize);
+
+    const int rank = inputShapeInfo[0];
+    REQUIRE_TRUE(rank == 4, 0, "SpaceToBatch: rank of input array must be equal 4, but got %i instead", rank);
+
+    const std::string expectedpaddingShape = "[2, 2]";
+    const std::string actualpaddingShape = ShapeUtils::shapeAsString(paddingShapeInfo);
+    REQUIRE_TRUE(actualpaddingShape == expectedpaddingShape, 0, "SpaceToBatch: operation expects padding shape to be {2, 2}, but got %s instead", actualpaddingShape.c_str());
+
+    const uint padBottom = INPUT_VARIABLE(1)->e<Nd4jLong>(0,0);
+    const uint padTop    = INPUT_VARIABLE(1)->e<Nd4jLong>(0,1);
+    const uint padLeft   = INPUT_VARIABLE(1)->e<Nd4jLong>(1,0);
+    const uint padRight  = INPUT_VARIABLE(1)->e<Nd4jLong>(1,1);
+
+    REQUIRE_TRUE((inputShapeInfo[2] + padBottom + padTop) % blockSize == 0 && (inputShapeInfo[3] + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !");
+
+    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), 'c', {inputShapeInfo[1] * blockSize * blockSize, (inputShapeInfo[2] + padBottom + padTop) / blockSize, (inputShapeInfo[3] + padLeft + padRight) / blockSize, inputShapeInfo[4]}));
 }
 
 }
 }
@@ -46,7 +46,7 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
     REQUIRE_TRUE(dim <= input->rankOf(), 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", input->shapeOf(), dim);
 
-    std::vector<NDArray*> inArrs(block.width());
+    std::vector<const NDArray*> inArrs(block.width());
     for(int i = 0; i < block.width(); ++i)
         inArrs[i] = INPUT_VARIABLE(i);
 
@ -31,26 +31,26 @@ namespace ops {
|
|||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(lstmBlock, 9, 7, false, 2, 2) {
|
CUSTOM_OP_IMPL(lstmBlock, 9, 7, false, 2, 2) {
|
||||||
auto maxTSLength = INPUT_VARIABLE(0);
|
auto maxTSLength = INPUT_VARIABLE(0);
|
||||||
auto x = INPUT_VARIABLE(1); // input [seqLen, bS, inSize] at time t
|
auto x = INPUT_VARIABLE(1); // input [seqLen, bS, nIn] at time t
|
||||||
-auto cLast = INPUT_VARIABLE(2); // previous cell state [bS, numUnits], time t-1
+auto cLast = INPUT_VARIABLE(2); // previous cell state [bS, nOut], time t-1
-auto yLast = INPUT_VARIABLE(3); // previous output [bS, numUnits], time t-1
+auto yLast = INPUT_VARIABLE(3); // previous output [bS, nOut], time t-1

-auto W = INPUT_VARIABLE(4); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]
+auto W = INPUT_VARIABLE(4); // Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut]
-auto Wci = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to input modulation gate, [numUnits]
+auto Wci = INPUT_VARIABLE(5); // weights - cell peephole (t-1) connections to input modulation gate, [nOut]
-auto Wcf = INPUT_VARIABLE(6); // weights - cell peephole (t-1) connections to forget gate, [numUnits]
+auto Wcf = INPUT_VARIABLE(6); // weights - cell peephole (t-1) connections to forget gate, [nOut]
-auto Wco = INPUT_VARIABLE(7); // weights - cell peephole (t) connections to output gate, [numUnits]
+auto Wco = INPUT_VARIABLE(7); // weights - cell peephole (t) connections to output gate, [nOut]
-auto b = INPUT_VARIABLE(8); // biases, [4*numUnits]
+auto b = INPUT_VARIABLE(8); // biases, [4*nOut]

-auto i = OUTPUT_VARIABLE(0); // Output - input modulation gate activations [seqLen, bS, numUnits]
+auto i = OUTPUT_VARIABLE(0); // Output - input modulation gate activations [seqLen, bS, nOut]
-auto c = OUTPUT_VARIABLE(1); // Activations, cell state (pre tanh) [seqLen, bs, numUnits]
+auto c = OUTPUT_VARIABLE(1); // Activations, cell state (pre tanh) [seqLen, bs, nOut]
-auto f = OUTPUT_VARIABLE(2); // Output - forget gate activations [seqLen, bs, numUnits]
+auto f = OUTPUT_VARIABLE(2); // Output - forget gate activations [seqLen, bs, nOut]
-auto o = OUTPUT_VARIABLE(3); // Output - output gate activations [seqLen, bs, numUnits]
+auto o = OUTPUT_VARIABLE(3); // Output - output gate activations [seqLen, bs, nOut]
-auto z = OUTPUT_VARIABLE(4); // Output - input gate activations [seqLen, bs, numUnits]
+auto z = OUTPUT_VARIABLE(4); // Output - input gate activations [seqLen, bs, nOut]
-auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [seqLen, bs, numUnits]
+auto h = OUTPUT_VARIABLE(5); // Cell state, post tanh [seqLen, bs, nOut]
 auto y = OUTPUT_VARIABLE(6); // current cell output [seqLen, bS, numProj], time t

 const int peephole = INT_ARG(0); // if 1, provide peephole connections
-const int dataFormat = INT_ARG(1); // 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]
+const int dataFormat = INT_ARG(1); // 0=TNS=[seqLen,bS,nIn]; 1=NST=[bS,nIn,seqLen]; 2=NTS=[bS,seqLen,nIn]
 const double forgetBias = T_ARG(0);
 const double clippingCellValue = T_ARG(1); // clipping value for ct, if it is not equal to zero, then cell state is clipped

@@ -86,7 +86,7 @@ DECLARE_SHAPE_FN(lstmBlock) {
 REQUIRE_TRUE(shape::rank(b)==1, 0, "lstmBlock: Biases must be rank 1");

-const int dataFormat = INT_ARG(1); // 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]
+const int dataFormat = INT_ARG(1); // 0=TNS=[seqLen,bS,size]; 1=NST=[bS,size,seqLen]; 2=NTS=[bS,seqLen,size]
 int bs;
 int t;
 int nOut = cLast[2]; //rank, bs, nOut, ...]
@@ -65,7 +65,7 @@ namespace nd4j {
 std::unique_ptr<ResultSet> rows(x->allTensorsAlongDimension({1}));

-PRAGMA_OMP_PARALLEL_FOR
+//PRAGMA_OMP_PARALLEL_FOR
 for (int r = 0; r < batchSize; r++) {
 auto row = rows->at(r);

@@ -77,14 +77,14 @@ namespace nd4j {
 int denseIdx = sparse2dense.at(idx);

-float value = row->e<float>(e + 1);
+float value = row->e<float>(e);
 float current = z->e<float>(r, denseIdx);
 z->p(r, denseIdx, value + current);
 }
 }

-STORE_RESULT(*z);
+//STORE_RESULT(*z);

 return Status::OK();
 }
@@ -93,7 +93,7 @@ namespace nd4j {
 auto inP = inputShape->at(0);

 std::vector<Nd4jLong> shape({shape::shapeOf(inP)[0], (Nd4jLong) block.getIArguments()->size()});
-auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', shape);
+auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inP), 'c', shape);
 return SHAPELIST(newShape);
 }
@@ -69,8 +69,9 @@ namespace ops {

 DECLARE_TYPES(clipbynorm_bp) {
 getOpDescriptor()
-->setAllowedInputTypes(nd4j::DataType::ANY)
-->setAllowedOutputTypes({ALL_FLOATS});
+->setAllowedInputTypes(0, DataType::ANY)
+->setAllowedInputTypes(1, {ALL_FLOATS})
+->setAllowedOutputTypes(0, {ALL_FLOATS});
 }
 }
 }
@@ -71,6 +71,7 @@ namespace ops {
 NDArray* rowCounts = NDArrayFactory::create_<int>('c', {N}); //rowP->dup();
 //srowCounts->assign(0);
 Nd4jLong len = helpers::barnes_row_count(rowP, colP, N, *rowCounts);
+rowCounts->syncToHost();
 // rowCounts->printBuffer("Row Counts");
 if (len <= 0) throw std::runtime_error("barnes_symmetrized: Cannot allocate shape due non-positive len.");
 rowCountsPtr = rowCounts;
@@ -91,7 +91,7 @@ namespace nd4j {
 #if NOT_EXCLUDED(OP_batchnorm)
 DECLARE_CUSTOM_OP(batchnorm, 3, 1, false, 1, 2);
 #endif
-#if NOT_EXCLUDED(OP_batchnorm)
+#if NOT_EXCLUDED(OP_batchnorm_new)
 DECLARE_CUSTOM_OP(batchnorm_new, 3, 1, false, 1, 2);
 #endif

@@ -614,7 +614,7 @@ namespace nd4j {
 *
 */
 #if NOT_EXCLUDED(OP_space_to_batch)
-DECLARE_CUSTOM_OP(space_to_batch, 1, 1, false, 0, -2);
+DECLARE_CUSTOM_OP(space_to_batch, 2, 1, false, 0, 1);
 #endif

 /**
@@ -622,7 +622,7 @@ namespace nd4j {
 *
 */
 #if NOT_EXCLUDED(OP_batch_to_space)
-DECLARE_CUSTOM_OP(batch_to_space, 1, 1, false, 0, -2);
+DECLARE_CUSTOM_OP(batch_to_space, 2, 1, false, 0, 1);
 #endif

 /**
@@ -1540,7 +1540,7 @@ namespace nd4j {
 * CAUTION: either size tensor or a pair of int params should be provided.
 */

-#if NOT_EXCLUDED(OP_resize_bilinear)
+#if NOT_EXCLUDED(OP_resize_nearest_neighbor)
 DECLARE_CUSTOM_OP(resize_nearest_neighbor, 1, 1, false, 0, -2);
 #endif
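The `batch_to_space` declaration above now takes a second input (the crop tensor). As a hedged illustration of the crop semantics this op family uses, here is a minimal scalar sketch in plain C++ (not the nd4j implementation): given an NHWC input that has already been rearranged to `[bS, iH, iW, iC]`, the output keeps only rows `[cropBottom, iH - cropTop)` and columns `[cropLeft, iW - cropRight)`. The function name and flat-vector layout are illustrative assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical scalar sketch of the batch_to_space crop step:
// input  [bS, iH, iW, iC] flattened row-major, output with the borders cropped.
std::vector<float> batchToSpaceCrop(const std::vector<float>& x,
                                    int bS, int iH, int iW, int iC,
                                    int cropBottom, int cropTop,
                                    int cropLeft, int cropRight) {
    const int oH = iH - cropBottom - cropTop;
    const int oW = iW - cropLeft - cropRight;
    std::vector<float> z(static_cast<size_t>(bS) * oH * oW * iC);
    for (int b = 0; b < bS; ++b)
        for (int h = cropBottom; h < iH - cropTop; ++h)
            for (int w = cropLeft; w < iW - cropRight; ++w)
                for (int c = 0; c < iC; ++c) {
                    // row-major NHWC offsets; only the cropped window is copied
                    const size_t xOff = ((static_cast<size_t>(b) * iH + h) * iW + w) * iC + c;
                    const size_t zOff = ((static_cast<size_t>(b) * oH + (h - cropBottom)) * oW
                                         + (w - cropLeft)) * iC + c;
                    z[zOff] = x[xOff];
                }
    return z;
}
```

With no cropping the copy is the identity; cropping one row from the bottom of a 2x2 image keeps only the top row of values.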
@@ -73,9 +73,11 @@ namespace helpers {
 symRowP[0] = 0;
 for (int n = 0; n < N; n++)
 symRowP[n + 1] = symRowP[n] + rowCounts->e<int>(n);
+// outputRows->printBuffer("output rows");

 int* symColP = reinterpret_cast<int*>(outputCols->buffer());
 // symRowP.p(n + 1, symRowP.e(n) + rowCounts.e(n))
-outputRows->printBuffer("SymRows are");
+// outputRows->printBuffer("SymRows are");
 int const* pCols = reinterpret_cast<int const*>(colP->getBuffer());
 T const* pVals = reinterpret_cast<T const*>(valP->getBuffer());
 T* pOutput = reinterpret_cast<T*>(outputVals->buffer());
@@ -145,27 +147,28 @@ namespace helpers {
 T* outputP = reinterpret_cast<T*>(output->buffer());
 int colCount = data->columns();

-std::vector<T> buffer(colCount);
-auto shift = 0;
+// auto shift = 0;
 auto rowSize = sizeof(T) * colCount;
+PRAGMA_OMP_PARALLEL_FOR
 for (int n = 0; n < N; n++) {
 int start = rowP->e<int>(n);
 int end = rowP->e<int>(n+1);
+int shift = n * colCount;
 for (int i = start; i < end; i++) {
 T const* thisSlice = dataP + colP->e<int>(i) * colCount;
 T res = 1;

 for (int k = 0; k < colCount; k++) {
-buffer[k] = dataP[shift + k] - thisSlice[k];//thisSlice[k];
+auto tempVal = dataP[shift + k] - thisSlice[k];//thisSlice[k];
-res += buffer[k] * buffer[k];
+res += tempVal * tempVal;
 }

 res = vals[i] / res;
 for (int k = 0; k < colCount; k++)
-outputP[shift + k] += (buffer[k] * res);
+outputP[shift + k] += ((dataP[shift + k] - thisSlice[k]) * res);
 }
-shift += colCount;
+//shift += colCount;
 }
 }
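The key change in the loop above is computing `shift = n * colCount` from the loop index instead of accumulating it serially, which removes cross-iteration state so the outer loop can legally carry the OpenMP pragma. A self-contained sketch of the fixed loop (plain C++ with CSR-style `rowP`/`colP`/`vals` inputs standing in for the NDArray accessors; the function name is illustrative):

```cpp
#include <cassert>
#include <vector>

// Sketch of the per-row offset fix: every iteration derives its own shift,
// so iterations are independent and the outer loop is parallelizable.
void edgeForces(const std::vector<int>& rowP, const std::vector<int>& colP,
                const std::vector<float>& vals, const std::vector<float>& data,
                int N, int colCount, std::vector<float>& output) {
    // #pragma omp parallel for   // safe now: no shared accumulator
    for (int n = 0; n < N; ++n) {
        const int shift = n * colCount;          // derived, not accumulated
        for (int i = rowP[n]; i < rowP[n + 1]; ++i) {
            const float* slice = data.data() + colP[i] * colCount;
            float res = 1.f;
            for (int k = 0; k < colCount; ++k) {
                const float d = data[shift + k] - slice[k];
                res += d * d;                    // 1 + squared distance
            }
            res = vals[i] / res;
            for (int k = 0; k < colCount; ++k)
                output[shift + k] += (data[shift + k] - slice[k]) * res;
        }
    }
}
```

Note the diff also drops the shared `buffer` vector for the same reason: a per-iteration temporary keeps the parallel loop race-free.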
@@ -61,13 +61,6 @@ namespace helpers {
 val = nd4j::math::nd4j_max(x_shape->e<Nd4jLong>(xLen - 1), y_shape->e<Nd4jLong>(e));
 }

-// if (e)
-// if (val != output->e<Nd4jLong>(e - 1)) {
-// nd4j_printf(
-// "broadcast_dynamic_shape: Input shapes should be compatible, but %lld and %lld were given.\n",
-// val, output->e<Nd4jLong>(e - 1));
-// return Status::CODE(ND4J_STATUS_VALIDATION, "broadcast_dynamic_shape: BDS validation failed!");
-// }
 output->p(e, val);
 }
 }
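The broadcast_dynamic_shape loop above applies the usual broadcasting rule: align the two shapes on the right and take the elementwise maximum. A minimal standalone sketch of that rule (not the nd4j helper; compatibility checking of mismatched non-1 dims is deliberately omitted here, as in the commented-out validation above):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Each output dim is the max of the two right-aligned input dims;
// the shorter shape is implicitly padded with leading 1s.
std::vector<int64_t> broadcastShape(std::vector<int64_t> x, std::vector<int64_t> y) {
    if (x.size() < y.size()) std::swap(x, y);    // make x the longer shape
    std::vector<int64_t> out(x);
    const size_t offset = x.size() - y.size();
    for (size_t e = 0; e < y.size(); ++e)
        out[offset + e] = std::max(x[offset + e], y[e]);
    return out;
}
```

For example, shapes `[2,1,4]` and `[3,1]` broadcast to `[2,3,4]`.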
libnd4j/include/ops/declarable/helpers/cpu/gradient.cpp (new file, 42 lines)
@@ -0,0 +1,42 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author sgazeos@gmail.com
+//
+
+#include <ops/declarable/helpers/axis.h>
+#include <op_boilerplate.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+template <typename T>
+static void applyGradientDescent_(NDArray* input, NDArray* step, double weight, NDArray* output) {
+auto lambda = LAMBDA_TT(_x, _y, weight) {
+return _x - (_y * weight);
+};
+
+input->applyPairwiseLambda<T>(step, lambda, output);
+}
+
+void applyGradientDescent(nd4j::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) {
+BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, (input, step, weight, output), FLOAT_TYPES);
+}
+BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, (NDArray* input, NDArray* step, double weight, NDArray* output), FLOAT_TYPES);
+}
+}
+}
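The new helper's pairwise lambda applies `output = input - step * weight` elementwise. A plain-C++ sketch of the same update, without the nd4j type dispatch (function name reused for illustration only):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Elementwise gradient-descent step: out[e] = in[e] - step[e] * weight,
// mirroring the LAMBDA_TT body dispatched over FLOAT_TYPES above.
void applyGradientDescent(const std::vector<float>& input, const std::vector<float>& step,
                          double weight, std::vector<float>& output) {
    output.resize(input.size());
    for (size_t e = 0; e < input.size(); ++e)
        output[e] = input[e] - static_cast<float>(step[e] * weight);
}
```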
@@ -70,7 +70,7 @@ namespace helpers {
 template <typename T>
 static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
 auto functor = LAMBDA_TT(x, y){
-return x >= (T)0.f? T(1.f) : T(0.f);
+return x >= (T)0.f? y : T(0.f);
 };

 input->applyPairwiseLambda<T>(epsilon, functor, output);
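The fix above matters for backprop correctness: for `x >= 0` the lambda must pass through the incoming gradient `y` (epsilon), not emit a constant 1, otherwise the upstream gradient is discarded. A standalone sketch of the corrected rule (the negative branch returns 0 here, mirroring the shown lambda; any alpha scaling is outside this snippet):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Backprop gate: grad_in[i] = (x[i] >= 0) ? eps[i] : 0
std::vector<float> leakyReluDerivative(const std::vector<float>& x,
                                       const std::vector<float>& eps) {
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i)
        out[i] = x[i] >= 0.f ? eps[i] : 0.f;   // pass gradient through, not 1
    return out;
}
```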
@@ -44,40 +44,40 @@ namespace helpers {
 void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b,
 NDArray* ht, NDArray* ct, const std::vector<double>& params) {

-// xt input [bS x inSize]
+// xt input [bS x nIn]
-// ht_1 previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!!
+// ht_1 previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=nOut!!!
-// ct_1 previous cell state [bS x numUnits], that is at previous time step t-1
+// ct_1 previous cell state [bS x nOut], that is at previous time step t-1

-// Wx input-to-hidden weights, [inSize x 4*numUnits]
+// Wx input-to-hidden weights, [nIn x 4*nOut]
-// Wh hidden-to-hidden weights, [numProj x 4*numUnits]
+// Wh hidden-to-hidden weights, [numProj x 4*nOut]
-// Wc diagonal weights for peephole connections [3*numUnits]
+// Wc diagonal weights for peephole connections [3*nOut]
-// Wp projection weights [numUnits x numProj]
+// Wp projection weights [nOut x numProj]
-// b biases, [4*numUnits]
+// b biases, [4*nOut]

 // ht current cell output [bS x numProj], that is at current time step t
-// ct current cell state [bS x numUnits], that is at current time step t
+// ct current cell state [bS x nOut], that is at current time step t

 const bool peephole = (bool)params[0]; // if true, provide peephole connections
-const bool projection = (bool)params[1]; // if true, then projection is performed, if false then numProj==numUnits is mandatory!!!!
+const bool projection = (bool)params[1]; // if true, then projection is performed, if false then numProj==nOut is mandatory!!!!
 double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped
 double clippingProjValue = params[3]; // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped
 const double forgetBias = params[4];

 const int bS = xt->sizeAt(0);
-const int inSize = xt->sizeAt(1);
+const int nIn = xt->sizeAt(1);
 const int numProj = ht_1->sizeAt(1);
-const int numUnits = ct_1->sizeAt(1);
+const int nOut = ct_1->sizeAt(1);

-auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + *b; // [bS x 4*numUnits] + [bS x 4*numUnits] + [1 x 4*numUnits] = [bS x 4*numUnits]
+auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + *b; // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut]

-auto zit = z({0,0, 0, numUnits}); // z for input gate, = mmul(Wxi,xt) + mmul(Whi,ht_1) + bi = [bS x numUnits]
+auto zit = z({0,0, 0,nOut}); // z for input gate, = mmul(Wxi,xt) + mmul(Whi,ht_1) + bi = [bS x nOut]
-auto zft = z({0,0, numUnits, 2*numUnits}); // z for forget gate, = mmul(Wxf,xt) + mmul(Whf,ht_1) + bf = [bS x numUnits]
+auto zft = z({0,0, nOut,2*nOut}); // z for forget gate, = mmul(Wxf,xt) + mmul(Whf,ht_1) + bf = [bS x nOut]
-auto zct = z({0,0, 2*numUnits, 3*numUnits}); // z for cell state, = mmul(Wxc,xt) + mmul(Whc,ht_1) + bc = [bS x numUnits]
+auto zct = z({0,0, 2*nOut,3*nOut}); // z for cell state, = mmul(Wxc,xt) + mmul(Whc,ht_1) + bc = [bS x nOut]
-auto zot = z({0,0, 3*numUnits, 4*numUnits}); // z for output gate, = mmul(Wxo,xt) + mmul(Who,ht_1) + bo = [bS x numUnits]
+auto zot = z({0,0, 3*nOut,4*nOut}); // z for output gate, = mmul(Wxo,xt) + mmul(Who,ht_1) + bo = [bS x nOut]

 if(peephole) { // add peephole connections: z + ct_1*Wc
-zit += (*ct_1) * (*Wc)({0, numUnits}); // add peephole connections to input gate
+zit += (*ct_1) * (*Wc)({0, nOut}); // add peephole connections to input gate
-zft += (*ct_1) * (*Wc)({numUnits, 2*numUnits}); // add peephole connections to forget gate
+zft += (*ct_1) * (*Wc)({nOut, 2*nOut}); // add peephole connections to forget gate
 }

 // current sell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc
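The gate slicing above splits the pre-activation matrix `z = x*Wx + h_1*Wh + b` into four contiguous `nOut`-wide blocks. The single-unit, no-peephole, no-projection case of that algebra can be sketched numerically in plain C++ (names here are illustrative, not nd4j API):

```cpp
#include <cassert>
#include <cmath>

// Scalar LSTM cell (bS = nIn = nOut = 1): the four gate pre-activations
// [i, f, c, o] are computed together, then combined as
// c = sigmoid(zf)*c_1 + sigmoid(zi)*tanh(zc), h = sigmoid(zo)*tanh(c).
struct LstmOut { double c; double h; };

static double sigmoidf(double v) { return 1.0 / (1.0 + std::exp(-v)); }

LstmOut lstmCellScalar(double x, double h1, double c1,
                       const double Wx[4], const double Wh[4], const double b[4]) {
    double z[4];
    for (int g = 0; g < 4; ++g)                 // gate order [i, f, c, o], as in zit/zft/zct/zot
        z[g] = x * Wx[g] + h1 * Wh[g] + b[g];
    const double c = sigmoidf(z[1]) * c1 + sigmoidf(z[0]) * std::tanh(z[2]);
    const double h = sigmoidf(z[3]) * std::tanh(c);
    return {c, h};
}
```

With all-zero weights and `c1 = 1`, every pre-activation is 0, so `c = 0.5` and `h = 0.5 * tanh(0.5)`, which makes the gate arithmetic easy to check by hand.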
@@ -85,20 +85,20 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h

 // if clipping value is provided then cell state is clipped by this value prior to the cell output activation
 if(clippingCellValue > 0.0)
-clipping(ct, clippingCellValue);
+ct->applyScalar(scalar::LstmClip, clippingCellValue);

 if(peephole)
-zot += (*ct) * (*Wc)({{2*numUnits, 3*numUnits}}); // add peephole connections to output gate zot + ct*Wc
+zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}}); // add peephole connections to output gate zot + ct*Wc

 // current cell output = ot*tanh(ct)
-auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x numUnits]
+auto htNoPeepHole = sigmoid(zot) * tanh(*ct); // = [bS x nOut]

 // apply projection
 if(projection) {
-ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x numUnits] * [ numUnits x numProj] = [bS x numProj]
+ht->assign( mmul(htNoPeepHole, *Wp) ); // [bS x nOut] * [ nOut x numProj] = [bS x numProj]
 // if clipping projection is provided then projected cell output state is clipped by this value
 if(clippingProjValue != 0.)
-clipping(ht, clippingProjValue);
+ht->applyScalar(scalar::LstmClip, clippingProjValue);
 }
 else
 ht->assign(&htNoPeepHole);
@@ -136,14 +136,14 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
 NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector<double>& params) {

 /* Input arrays:
-* 0: xt - input [bS, inSize] at time t
+* 0: xt - input [bS, nIn] at time t
-* 1: cLast (cs_prev) - previous cell state [bS, numUnits], time t-1
+* 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1
-* 2: yLast (h_prev) - previous output [bS, numUnits], time t-1
+* 2: yLast (h_prev) - previous output [bS, nOut], time t-1
-* 3: W - Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(inSize+numUnits), 4*numUnits]
+* 3: W - Weights - concatenated (input-to-hidden, hidden-to-hidden weights) weights, [(nIn+nOut), 4*nOut]
-* 4: Wci - weights - cell peephole (t-1) connections to input modulation gate, [numUnits]
+* 4: Wci - weights - cell peephole (t-1) connections to input modulation gate, [nOut]
-* 5: Wcf - weights - cell peephole (t-1) connections to forget gate, [numUnits]
+* 5: Wcf - weights - cell peephole (t-1) connections to forget gate, [nOut]
-* 6: Wco - weights - cell peephole (t) connections to output gate, [numUnits]
+* 6: Wco - weights - cell peephole (t) connections to output gate, [nOut]
-* 7: b - biases, [4*numUnits]
+* 7: b - biases, [4*nOut]
 *
 * Input integer arguments:
 * 0: if not zero, provide peephole connections
@@ -153,38 +153,34 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
 * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped
 *
 * Output arrays:
-* 0: i - Input modulation gate activations [bS, numUnits]
+* 0: i - Input modulation gate activations [bS, nOut]
-* 1: c (cs) - Cell state (pre tanh) [bs, numUnits] (cs)
+* 1: c (cs) - Cell state (pre tanh) [bs, nOut] (cs)
-* 2: f - Output - forget gate activations [bs, numUnits]
+* 2: f - Output - forget gate activations [bs, nOut]
-* 3: o - Output - output gate activations [bs, numUnits]
+* 3: o - Output - output gate activations [bs, nOut]
-* 4: z (ci) - Output - block input [bs, numUnits]
+* 4: z (ci) - Output - block input [bs, nOut]
-* 5: h (co) - Cell state, post tanh [bs, numUnits]
+* 5: h (co) - Cell state, post tanh [bs, nOut]
-* 6: y (h) - Current cell output [bS, numUnits], time t
+* 6: y (h) - Current cell output [bS, nOut], time t
 */
 const bool peephole = (bool)params[0]; // if true, provide peephole connections
 const double forgetBias = params[1];
 const double clippingCellValue = params[2]; // clipping value for ct, if it is not equal to zero, then cell state is clipped

 const int bS = xt->sizeAt(0);
-const int inSize = xt->sizeAt(1);
+const int nIn = xt->sizeAt(1);
-const int numUnits = cLast->sizeAt(1);
+const int nOut = cLast->sizeAt(1);

 //Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)]
-auto concatOut = NDArrayFactory::create(xt->ordering(), {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, xt->dataType(), xt->getContext());
+NDArray concatOut(xt->ordering(), {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, xt->dataType(), xt->getContext());
 helpers::concat(xt->getContext(), {const_cast<NDArray*>(xt), const_cast<NDArray*>(yLast)}, concatOut, {1});

-//NDArray* NDArrayFactory::create_( const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dataType, nd4j::memory::Workspace* workspace) {
-std::vector<Nd4jLong> shape = {bS, 4*numUnits};
-auto m = NDArrayFactory::create('c', shape, xt->dataType());
-MmulHelper::mmul(&concatOut, W, &m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array
+auto m = mmul(concatOut, *W); // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut]
 m += (*b); // addiRowVector

 //Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o])
-auto zi = (m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits]
+auto zi = m({0,0, 0, nOut}); // z for input modulation gate, [bS, nOut]
-auto zz = (m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits]
+auto zz = m({0,0, nOut, 2*nOut}); // z for block input, [bS, nOut]
-auto zf = (m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits]
+auto zf = m({0,0, 2*nOut, 3*nOut}); // z for forget gate, [bS, nOut]
-auto zo = (m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits]
+auto zo = m({0,0, 3*nOut, 4*nOut}); // z for output gate, [bS, nOut]

 if(peephole) { // add peephole connections: z + ct_1*Wc
 zi += (*cLast) * (*Wci); // add peephole connections to input gate
@@ -192,9 +188,8 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
 }

 // current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc
-if(forgetBias != 0.0){
+if(forgetBias != 0.0)
 zf += forgetBias;
-}

 PRAGMA_OMP_PARALLEL
 PRAGMA_OMP_SINGLE
@@ -209,7 +204,6 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
 zf.applyTransform(transform::Sigmoid, f); //f = sigmoid(zf);
 }

 if (z->ews() == 1 && i->ews() == 1 && c->ews() == 1 && cLast->ews() == 1 && f->ews() == 1 && h->ews() == 1 &&
 z->ordering() == i->ordering() && z->ordering() == c->ordering() && z->ordering() == cLast->ordering() && z->ordering() == f->ordering() && z->ordering() == h->ordering()) {
 //cell state = blockInput .* inputGate + prevCellState .* forgetGate
@@ -223,15 +217,15 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
 }

 // if clipping value is provided then cell state is clipped by this value prior to the cell output activation
-if(clippingCellValue > 0.0) {
+if(clippingCellValue > 0.0)
-clipping(c, clippingCellValue);
+c->applyScalar(scalar::LstmClip, clippingCellValue);
-}

-if(peephole) {
 // add peephole connections to output gate zot + ct*Wc
+if(peephole) {
 auto prod = *c * (*Wco);
 zo += prod;
 }

 zo.applyTransform(transform::Sigmoid, o); // o = sigmoid(zo)

 // current cell output = ot*tanh(ct)
@ -15,7 +15,8 @@
|
|||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by raver119 on 19.01.18.
|
// @author raver119@gmail.com, created on 19.01.18.
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/s_t_b.h>
|
#include <ops/declarable/helpers/s_t_b.h>
|
||||||
@ -25,6 +26,146 @@ namespace ops {
|
|||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template <typename T>
|
||||||
|
static void batchToSpace_(const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight) {
|
||||||
|
|
||||||
|
// input [bS, H * blockSize, W * blockSize, iC]
|
||||||
|
// output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft - cropRight, iC]
|
||||||
|
|
||||||
|
// if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same
|
||||||
|
// else:
|
||||||
|
    // oH -> [cropBottom, iH - cropTop]
    // oW -> [cropLeft, iW - cropRight]
    // xLen > zLen

    const T* x = input.bufferAsT<T>();
    T* z = output.bufferAsT<T>();

    const int rank = 4;

    const Nd4jLong* xShapeInfo = input.getShapeInfo();
    const Nd4jLong* zShapeInfo = output.getShapeInfo();

    const uint bS = xShapeInfo[1];
    const uint iH = xShapeInfo[2];
    const uint iW = xShapeInfo[3];
    const uint iC = xShapeInfo[4];

    // loop through output array
    PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(4))
    for (uint b = 0; b < bS; ++b) {
        for (uint h = cropBottom; h < iH - cropTop; ++h) {
            for (uint w = cropLeft; w < iW - cropRight; ++w) {
                for (uint c = 0; c < iC; ++c) {

                    const Nd4jLong xOffset = b * xShapeInfo[5] + h * xShapeInfo[6] + w * xShapeInfo[7] + c * xShapeInfo[8];
                    const Nd4jLong zOffset = b * zShapeInfo[5] + (h - cropBottom) * zShapeInfo[6] + (w - cropLeft) * zShapeInfo[7] + c * zShapeInfo[8];

                    z[zOffset] = x[xOffset];
                }
            }
        }
    }
}

BUILD_SINGLE_TEMPLATE(template void batchToSpace_, (const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight), LIBND4J_TYPES);

//////////////////////////////////////////////////////////////////////////
void batchToSpace(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight, const uint blockSize) {

    // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is rearranged/permuted to [bS, oH, oW, iC]
    // oH = H - cropTop - cropBottom
    // oW = W - cropLeft - cropRight

    NDArray inputRearranged0 = input.reshape(input.ordering(), {blockSize, blockSize, output.sizeAt(0), input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)});
    inputRearranged0.permutei({2, 3, 0, 4, 1, 5});

    if(input.lengthOf() == output.lengthOf())
        output.assign(inputRearranged0);
    else {
        NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, input.sizeAt(2) * blockSize, input.sizeAt(3)});
        BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpace_, (inputRearranged1, output, cropBottom, cropTop, cropLeft, cropRight), LIBND4J_TYPES);
    }
}
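The rearrange-then-crop semantics of the helper above can be modeled in a few lines of plain Python (a sketch, not the nd4j implementation): the `[bS*block*block, H, W, C]` input is treated as `[block, block, bS, H, W, C]`, the two block phases are interleaved into the spatial axes, and the crop window is applied.

```python
# Pure-Python model of NHWC batch_to_space, mirroring the reshape/permute/crop
# done by batchToSpace above. x is a nested list [bS*block*block][H][W][C].
def batch_to_space(x, block, crop_bottom, crop_top, crop_left, crop_right):
    n, h, w = len(x), len(x[0]), len(x[0][0])
    c = len(x[0][0][0])
    bs = n // (block * block)
    oh = h * block - crop_bottom - crop_top
    ow = w * block - crop_left - crop_right
    out = [[[[0] * c for _ in range(ow)] for _ in range(oh)] for _ in range(bs)]
    for b in range(bs):
        for i in range(h * block):
            for j in range(w * block):
                inside = (crop_bottom <= i < h * block - crop_top and
                          crop_left <= j < w * block - crop_right)
                if inside:
                    # the batch index encodes the (i % block, j % block) phase
                    src_batch = (i % block) * block * bs + (j % block) * bs + b
                    out[b][i - crop_bottom][j - crop_left] = x[src_batch][i // block][j // block]
    return out
```

With `block = 2` and four 1x1x1 batches, the four batch entries interleave into one 2x2 spatial grid; cropping then trims rows/columns off that grid.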

//////////////////////////////////////////////////////////////////////////
template <typename T>
static void spaceToBatch_(const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight) {

    // input  [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - padRight, iC]
    // output [bS, H * blockSize, W * blockSize, iC]

    // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same
    // else:
    // iH -> [padBottom, oH - padTop]
    // iW -> [padLeft, oW - padRight]
    // zLen > xLen

    const T* x = input.bufferAsT<T>();
    T* z = output.bufferAsT<T>();

    const int rank = 4;

    const Nd4jLong* xShapeInfo = input.getShapeInfo();
    const Nd4jLong* zShapeInfo = output.getShapeInfo();

    const uint bS = zShapeInfo[1];
    const uint oH = zShapeInfo[2];
    const uint oW = zShapeInfo[3];
    const uint iC = zShapeInfo[4];

    // loop through output array
    PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(4))
    for (uint b = 0; b < bS; ++b) {
        for (uint h = 0; h < oH; ++h) {
            for (uint w = 0; w < oW; ++w) {
                for (uint c = 0; c < iC; ++c) {

                    const Nd4jLong zOffset = b * zShapeInfo[5] + h * zShapeInfo[6] + w * zShapeInfo[7] + c * zShapeInfo[8];

                    if(h >= padBottom && h < oH - padTop && w >= padLeft && w < oW - padRight) {
                        const Nd4jLong xOffset = b * xShapeInfo[5] + (h - padBottom) * xShapeInfo[6] + (w - padLeft) * xShapeInfo[7] + c * xShapeInfo[8];
                        z[zOffset] = x[xOffset];
                    }
                    else
                        z[zOffset] = 0.f;
                }
            }
        }
    }
}

BUILD_SINGLE_TEMPLATE(template void spaceToBatch_, (const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES);

//////////////////////////////////////////////////////////////////////////
void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight, const uint blockSize) {

    // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]

    NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)});
    outputRearranged0.permutei({2, 3, 0, 4, 1, 5});

    if(input.lengthOf() == output.lengthOf()) {
        outputRearranged0.assign(input);
    }
    else {
        NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)});
        BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatch_, (input, outputRearranged1, padBottom, padTop, padLeft, padRight), LIBND4J_TYPES);

        if(output.getBuffer() != outputRearranged1.getBuffer())
            outputRearranged0.assign(outputRearranged1);
    }
}

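space_to_batch is the inverse rearrangement: pad the spatial grid, then distribute each block phase into its own batch slice. A small pure-Python sketch of that semantics (illustrative only, not the nd4j code):

```python
# Pure-Python model of NHWC space_to_batch: x is [bS][iH][iW][C] nested lists,
# the result is [bS*block*block][oH][oW][C] with zero padding where no input
# element lands, mirroring the z[zOffset] = 0 branch of spaceToBatch_ above.
def space_to_batch(x, block, pad_bottom, pad_top, pad_left, pad_right):
    bs, ih, iw = len(x), len(x[0]), len(x[0][0])
    c = len(x[0][0][0])
    ph = (ih + pad_bottom + pad_top) // block     # padded height / block
    pw = (iw + pad_left + pad_right) // block     # padded width / block
    out = [[[[0] * c for _ in range(pw)] for _ in range(ph)]
           for _ in range(bs * block * block)]
    for b in range(bs):
        for i in range(ph * block):
            for j in range(pw * block):
                dst = (i % block) * block * bs + (j % block) * bs + b
                if pad_bottom <= i < ih + pad_bottom and pad_left <= j < iw + pad_left:
                    out[dst][i // block][j // block] = x[b][i - pad_bottom][j - pad_left]
    return out
```

Without padding this exactly undoes the batch_to_space rearrangement; with padding, the padded positions stay zero in the corresponding batch slices.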
/*
    template <int N, bool B2S>
    struct SpaceToBatchHelper {
        template <typename T>

@@ -124,6 +265,8 @@ namespace helpers {

    #undef STB_BOOL
    #undef STB_DIM
*/

}
}
}
@@ -30,14 +30,14 @@ namespace helpers {

///////////////////////////////////////////////////////////////////
template <typename T>
-static void stack_(const std::vector<NDArray*>& inArrs, NDArray* outArr, const int dim) {
+static void stack_(const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {

    if(inArrs[0]->rankOf() == 0) {
        int inSize = inArrs.size();

        PRAGMA_OMP_PARALLEL_FOR_IF(inSize > Environment::getInstance()->tadThreshold())
        for(int i=0; i < inSize; ++i)
-            outArr->p(i, inArrs[i]->e<T>(0));
+            outArr->p<T>(i, inArrs[i]->t<T>(0));
    }
    else {

@@ -53,11 +53,11 @@ static void stack_(const std::vector<NDArray*>& inArrs, NDArray* outArr, const i
}
}

-void stack(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray* outArr, const int dim) {
+void stack(nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
    BUILD_SINGLE_SELECTOR(outArr->dataType(), stack_, (inArrs, outArr, dim), LIBND4J_TYPES);
}

-BUILD_SINGLE_TEMPLATE(template void stack_ , (const std::vector<NDArray*>& inArrs, NDArray* outArr, const int dim), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template void stack_ , (const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim), LIBND4J_TYPES);

}
}
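Conceptually, stack inserts a new axis at position `dim` and lays the input arrays along it; the rank-0 branch above is just the case where stacking scalars yields a vector. A tiny nested-list sketch of that semantics (not the nd4j TAD-based implementation):

```python
# Recursive model of stack over nested lists: dim == 0 concatenates the arrays
# along a new leading axis; deeper dims recurse into the zipped sub-elements.
def stack(arrs, dim):
    if dim == 0:
        return [a for a in arrs]
    return [stack(sub, dim - 1) for sub in zip(*arrs)]
```

Stacking scalars with `dim == 0` gives a vector, matching the rank-0 fast path; `dim == 1` over two vectors transposes them into pairs.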
@@ -559,68 +559,73 @@ void invertPermutation(nd4j::LaunchContext * context, const NDArray& input, NDAr
}

////////////////////////////////////////////////////////////////////////
-template<typename T>
+template<typename X, typename Y>
static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {

-    if (input.ordering() != 'c')
-        input.streamline('c');
-
-    if (indices.ordering() != 'c')
-        indices.streamline('c');
-
-    const int rankIn = input.rankOf();
-    const int rankInd = indices.rankOf();
-    const int lastIndDim = indices.sizeAt(-1);
-
-    std::vector<int> tadDims(rankIn - lastIndDim);
-    std::iota(tadDims.begin(), tadDims.end(), rankInd-1);
-    auto innerMostOut = output.allTensorsAlongDimension(tadDims);
-
-    auto innerMostInd = indices.allTensorsAlongDimension({rankInd-1});
-
-    std::iota(tadDims.begin(), tadDims.end(), lastIndDim);
-    auto innerMostIn = input.allTensorsAlongDimension(tadDims);
-
-    Nd4jLong* outerShapeInfo = nullptr;
-    ALLOCATE(outerShapeInfo, input.getContext()->getWorkspace(), shape::shapeInfoLength(lastIndDim), Nd4jLong);
-    outerShapeInfo[0] = lastIndDim;
-    for(int i = 1; i <= lastIndDim; ++i)
-        outerShapeInfo[i] = input.sizeAt(i-1);
-    shape::updateStrides(outerShapeInfo, input.ordering());
-
-    Nd4jLong idx[MAX_RANK];
-
-    for(int i = 0; i < innerMostInd->size(); ++i) {
-
-        auto idxSubArr = innerMostInd->at(i);
-
-        for(int j = 0; j < lastIndDim; ++j) {
-            if(idxSubArr->e<Nd4jLong>(j) >= input.sizeAt(j))
-                throw std::runtime_error("helpers::gatherND function: indices array contains wrong elements, each element must be smaller than corresponding dimension of input array !");
-            idx[j] = idxSubArr->e<Nd4jLong>(j);
-        }
-
-        auto currentInd0 = shape::getOffset(0, shape::shapeOf(outerShapeInfo), shape::stride(outerShapeInfo), idx, lastIndDim);
-
-        if(rankIn != lastIndDim) {
-            auto outSubArr = innerMostOut->at(i);
-            outSubArr->assign(innerMostIn->at(currentInd0));
-        }
-        else
-            output.p(i, input.e<T>(currentInd0));
-    }
-
-    delete innerMostInd;
-    delete innerMostIn;
-    delete innerMostOut;
-    RELEASE(outerShapeInfo, input.getContext()->getWorkspace());
+    const X* x = reinterpret_cast<X*>(input.getBuffer());
+    const Y* y = reinterpret_cast<Y*>(indices.getBuffer());
+    X* z = reinterpret_cast<X*>(output.getBuffer());
+
+    const int xRank = input.rankOf();
+    const int yRank = indices.rankOf();
+    const int zRank = output.rankOf();
+    const int maxRank = nd4j::math::nd4j_max<int>(yRank, nd4j::math::nd4j_max<int>(xRank, zRank));
+
+    const Nd4jLong zLen = output.lengthOf();
+
+    const int yLastDim = indices.sizeAt(-1);
+
+    std::vector<Nd4jLong> coords(maxRank);
+
+    PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
+    for (Nd4jLong i = 0; i < zLen; ++i) {
+
+        Nd4jLong *zCoordStart, *xCoordStart;
+
+        if(yLastDim == xRank) {
+            zCoordStart = coords.data();
+            xCoordStart = coords.data();
+        }
+        else if(zRank >= xRank) {
+            zCoordStart = coords.data();
+            xCoordStart = coords.data() + zRank - xRank;
+        }
+        else {
+            zCoordStart = coords.data() + xRank - zRank;
+            xCoordStart = coords.data();
+        }
+
+        shape::index2coords(zRank, output.shapeOf(), i, zLen, zCoordStart);
+
+        const auto zOffset = shape::getOffset(0, output.shapeOf(), output.stridesOf(), zCoordStart, zRank);
+
+        // last y coordinate
+        uint coordToRestore;
+        if(yLastDim != xRank)
+            coordToRestore = static_cast<uint>(zCoordStart[yRank - 1]);
+
+        zCoordStart[yRank - 1] = 0;
+        const auto yOffset = shape::getOffset(0, indices.shapeOf(), indices.stridesOf(), zCoordStart, yRank);
+
+        // restore z coordinate
+        if(yLastDim != xRank)
+            zCoordStart[yRank - 1] = coordToRestore;
+
+        // construct coordinates for x
+        for(uint j = 0; j < yLastDim; ++j)
+            xCoordStart[j] = y[yOffset + j * indices.stridesOf()[yRank - 1]]; // last stride
+
+        const auto xOffset = shape::getOffset(0, input.shapeOf(), input.stridesOf(), xCoordStart, xRank);
+
+        z[zOffset] = x[xOffset];
+    }
}

////////////////////////////////////////////////////////////////////////
void gatherND(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) {
-    BUILD_SINGLE_SELECTOR(input.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES);
+    BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INTEGER_TYPES);
}

-BUILD_SINGLE_TEMPLATE(template void gatherND_, (NDArray& input, NDArray& indices, NDArray& output), LIBND4J_TYPES);
+BUILD_DOUBLE_TEMPLATE(template void gatherND_, (NDArray& input, NDArray& indices, NDArray& output), LIBND4J_TYPES, INTEGER_TYPES);

////////////////////////////////////////////////////////////////////////
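The index math above implements the usual gather_nd contract: the innermost vectors of `indices` are (possibly partial) coordinates into `input`, and each one selects either a scalar or a whole trailing slice. A recursive pure-Python sketch of that contract (illustrative, not the strided implementation):

```python
# gather_nd over nested lists: each innermost index vector walks that many
# leading axes of params; shorter vectors return the remaining sub-array.
def gather_nd(params, indices):
    if not isinstance(indices[0], list):   # a single coordinate vector
        out = params
        for k in indices:
            out = out[k]
        return out
    return [gather_nd(params, idx) for idx in indices]
```

Full-rank index vectors gather scalars; length-1 vectors into a rank-2 `params` gather rows, which is the `yLastDim != xRank` case handled by the coordinate-restore logic above.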
@@ -900,66 +905,85 @@ template<typename T>
static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {

    const int rank = input.rankOf();
-    auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions);
+    const auto norm2 = input.reduceAlongDims(reduce::Norm2, dimensions);
+
+    const T normActual = norm2.e<T>(0);
+    const T normClip = clipNorm.e<T>(0);

    if (isInplace) {

        if(norm2.lengthOf() == 1) {

-            if(norm2.e<T>(0) > clipNorm.e<T>(0))
-                input *= (clipNorm.e<T>(0) / norm2.e<T>(0));
+            if(normActual > normClip)
+                input *= (normClip / normActual);
        }
        else {

-            std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions);
-            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.getShapeInfo(), dimsToExclude);
+            auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);

            PRAGMA_OMP_PARALLEL_FOR
-            for(Nd4jLong i = 0; i < numOfSubArrs; ++i) {
-                if (norm2.e<T>(i) > clipNorm.e<T>(0)) {
-
-                    auto inputSubArr = input(i, dimsToExclude);
-                    inputSubArr *= (clipNorm.e<T>(0) / norm2.e<T>(i));
-                }
+            for(Nd4jLong i = 0; i < listOfInSubArrs->size(); ++i) {
+
+                const T iNormActual = norm2.e<T>(i);
+
+                if (iNormActual > normClip)
+                    *listOfInSubArrs->at(i) *= normClip / iNormActual;
            }
+            delete listOfInSubArrs;
        }
    }
    else {

        if(norm2.lengthOf() == 1) {

-            if(norm2.e<T>(0) > clipNorm.e<T>(0))
-                output.assign( input * (clipNorm / norm2.e<T>(0)));
+            if(normActual > normClip)
+                output.assign(input * (normClip / normActual));
            else
                output.assign(input);
        }
        else {

-            std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions);
-            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.getShapeInfo(), dimsToExclude);
-            std::vector<Nd4jLong> idxRanges(rank * 2);
+            auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);
+            auto listOfOutSubArrs = output.allTensorsAlongDimension(dimensions);

-            PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(idxRanges))
-            for(Nd4jLong i = 0; i < numOfSubArrs; ++i) {
-
-                ShapeUtils::evalIdxRangesForSubArr(i, input.getShapeInfo(), dimsToExclude, idxRanges.data());
-
-                auto outputSubArr = output(idxRanges);
-                auto inputSubArr = input(idxRanges);
-                outputSubArr.assign(inputSubArr);
-
-                if (norm2.e<T>(i) > clipNorm.e<T>(0))
-                    outputSubArr *= clipNorm / norm2.e<T>(i);
+            PRAGMA_OMP_PARALLEL_FOR
+            for(Nd4jLong i = 0; i < listOfInSubArrs->size(); ++i) {
+
+                auto inputSubArr = listOfInSubArrs->at(i);
+                auto outputSubArr = listOfOutSubArrs->at(i);
+                outputSubArr->assign(inputSubArr);
+
+                const T iNormActual = norm2.e<T>(i);
+
+                if (iNormActual > clipNorm.e<T>(0))
+                    *outputSubArr *= clipNorm / iNormActual;
            }

+            delete listOfInSubArrs;
+            delete listOfOutSubArrs;
        }
    }
}

//////////////////////////////////////////////////////////////////////////
void clipByNorm(nd4j::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);

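The per-sub-array rule implemented above is simple: rescale only when the L2 norm exceeds the clip value, otherwise pass through unchanged. A one-vector Python sketch of that rule (illustrative, not the nd4j code):

```python
import math

# Clip a single vector by its L2 norm, as clipByNorm_ does per sub-array:
# scale by clip/norm only when norm > clip, otherwise return the input as-is.
def clip_by_norm(vec, clip):
    n = math.sqrt(sum(v * v for v in vec))
    return [v * clip / n for v in vec] if n > clip else list(vec)
```

A `[6, 8]` vector has norm 10, so clipping at 5 halves it; `[3, 4]` already sits at norm 5 and is left untouched (the comparison is strict).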
template <typename T>
static void clipByGlobalNorm_(std::vector<NDArray*> const& inputs, double clipNorm, nd4j::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
    NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))

@@ -1026,37 +1050,41 @@ static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& g
    }
    else {

-        std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions);
-        const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.getShapeInfo(), dimsToExclude);
-        std::vector<Nd4jLong> idxRanges(rank * 2);
-
-        PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(idxRanges))
-        for(Nd4jLong i = 0; i < numOfSubArrs; ++i) {
-
-            ShapeUtils::evalIdxRangesForSubArr(i, input.getShapeInfo(), dimsToExclude, idxRanges.data());
-            T N = norm2.e<T>(i);
-
-            auto gradOSubArr = gradO(idxRanges);
-            auto gradISubArr = gradI(idxRanges);
+        const auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions});
+        const auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions});
+        const auto inputSubArrs = input.allTensorsAlongDimension({dimensions});

        auto cn = clipNorm.e<T>(0);

+        PRAGMA_OMP_PARALLEL_FOR
+        for(Nd4jLong i = 0; i < gradISubArrs->size(); ++i) {
+
+            T N = norm2.e<T>(i);
+
+            auto gradOSubArr = gradOSubArrs->at(i);
+            auto gradISubArr = gradISubArrs->at(i);

            if (N > cn) {

-                auto inputSubArr = input(idxRanges);
+                auto inputSubArr = inputSubArrs->at(i);

-                const T sumOfProd = (inputSubArr * gradOSubArr).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
+                const T sumOfProd = (*inputSubArr * *gradOSubArr).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
                const T factor1 = static_cast<T>(1.f) / N;
                const T factor3 = factor1 / (N * N) ; // 1 / (N*N*N)

                auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) {
                    return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd);
                };
-                inputSubArr.applyPairwiseLambda<T>(&gradOSubArr, lambda, &gradISubArr);
+                inputSubArr->applyPairwiseLambda<T>(gradOSubArr, lambda, gradISubArr);
            }
            else
-                gradISubArr.assign(gradOSubArr);
+                gradISubArr->assign(gradOSubArr);
        }

+        delete gradISubArrs;
+        delete gradOSubArrs;
+        delete inputSubArrs;
    }
}
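The lambda in the clipped branch encodes the derivative of `f(x) = cn * x / N` with `N = ||x||`: `gradI = cn * (gradO / N - x * sum(x * gradO) / N^3)`. A small Python sketch that implements this formula and checks it against the forward pass by central finite differences (an illustrative model, not the nd4j code):

```python
import math

def clip_forward(x, cn):
    # forward pass of clip-by-norm on one vector
    n = math.sqrt(sum(v * v for v in x))
    return [v * cn / n for v in x] if n > cn else list(x)

def clip_backward(x, grad_out, cn):
    # mirrors the LAMBDA_TT above: cn * (gradO/N - x * sum(x*gradO) / N^3)
    n = math.sqrt(sum(v * v for v in x))
    if n <= cn:
        return list(grad_out)
    s = sum(a * g for a, g in zip(x, grad_out))
    return [cn * (g / n - a * s / n ** 3) for a, g in zip(x, grad_out)]
```

Differentiating `cn * x_j / N` with respect to `x_i` gives `cn * (delta_ij / N - x_i * x_j / N^3)`; contracting with `gradO` yields exactly the two-term expression, which the finite-difference check below confirms numerically.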
@@ -24,12 +24,10 @@ namespace nd4j {
namespace ops {
namespace helpers {

-    Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, NDArray& rowCounts) {
-
-        int* pRowCounts = reinterpret_cast<int*>(rowCounts.buffer());
-        int const* pRows = reinterpret_cast<int const*>(rowP->getBuffer());
-        int const* pCols = reinterpret_cast<int const*>(colP->getBuffer());
-        for (int n = 0; n < N; n++) {
+    static __global__ void countRowsKernel(int* pRowCounts, int const* pRows, int const* pCols, Nd4jLong N) {
+        auto start = blockIdx.x * blockDim.x;
+        auto step = blockDim.x * gridDim.x;
+        for (int n = threadIdx.x + start; n < N; n += step) {
            int begin = pRows[n];//->e<int>(n);
            int end = pRows[n + 1];//rowP->e<int>(n + 1);
            for (int i = begin; i < end; i++) {
@@ -40,51 +38,189 @@ namespace helpers {
                    break;
            }

-            ++pRowCounts[n];
+            atomicAdd(&pRowCounts[n], 1);

            if (!present)
-                ++pRowCounts[pCols[i]];
+                atomicAdd(&pRowCounts[pCols[i]], 1);
        }
    }
}

Nd4jLong barnes_row_count(const NDArray* rowP, const NDArray* colP, Nd4jLong N, NDArray& rowCounts) {

    int* pRowCounts = reinterpret_cast<int*>(rowCounts.specialBuffer());
    int const* pRows = reinterpret_cast<int const*>(rowP->getSpecialBuffer());
    int const* pCols = reinterpret_cast<int const*>(colP->getSpecialBuffer());
    auto stream = rowCounts.getContext()->getCudaStream();
    countRowsKernel<<<1, 1, 128, *stream>>>(pRowCounts, pRows, pCols, N);
    NDArray numElementsArr = rowCounts.sumNumber(); //reduceAlongDimension(reduce::Sum, {});
    //rowCounts.printBuffer("Row counts");
    auto numElements = numElementsArr.e<Nd4jLong>(0);
    return numElements;
}
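The counting rule in the kernel above sizes the symmetrized sparse matrix: every stored entry `(n, j)` contributes to row `n`, and also to row `j` when the transposed entry `(j, n)` is not stored. A CSR-based Python sketch of the same rule (illustrative, not the CUDA kernel):

```python
# Count nonzeros per row of the symmetrized pattern P + P^T, given P in CSR
# form (rows = row pointers, cols = column indices), as countRowsKernel does.
def sym_row_counts(rows, cols, n):
    counts = [0] * n
    for i in range(n):
        for k in range(rows[i], rows[i + 1]):
            j = cols[k]
            # is the mirrored entry (j, i) already stored?
            present = any(cols[m] == i for m in range(rows[j], rows[j + 1]))
            counts[i] += 1
            if not present:
                counts[j] += 1
    return counts
```

A one-sided entry `(0, 1)` is counted once for row 0 and once for row 1; a pair that is already symmetric is counted once per row with no extra mirror contribution.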
static __global__ void fillUpsymRow(int const* pRowCounts, int* symRowP, int N) {

    auto start = blockIdx.x * blockDim.x + threadIdx.x;
    auto step = blockDim.x * gridDim.x;

    for (int n = start; n < N + 1; n += step) {
        symRowP[n] = 0;
        for (int i = 0; i < n; i++)
            atomicAdd(&symRowP[n], pRowCounts[i]);
    }
}

template <typename T>
static __global__ void symmetrizeKernel(int const* pRows, int const* pCols, T const* pVals, int* symRowP, int* symColP, int* offset, T* pOutput, int N) {
    auto start = blockIdx.x * blockDim.x + threadIdx.x;
    auto step = blockDim.x * gridDim.x;

    for (int n = start; n < N; n += step) {
        int begin = pRows[n];
        int bound = pRows[n + 1];

        for (int i = begin; i < bound; i++) {
            bool present = false;
            int colPI = pCols[i];
            int start = pRows[colPI];
            int end = pRows[colPI + 1];

            //PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(offset))
            for (int m = start; m < end; m++) {
                if (pCols[m] == n) {
                    present = true;
                    if (n <= colPI) {
                        symColP[symRowP[n] + offset[n]] = colPI;
                        symColP[symRowP[colPI] + offset[colPI]] = n;
                        pOutput[symRowP[n] + offset[n]] = pVals[i] + pVals[m];
                        pOutput[symRowP[colPI] + offset[colPI]] = pVals[i] + pVals[m];
                    }
                }
            }

            // If (colP[i], n) is not present, there is no addition involved
            if (!present) {
                //int colPI = pCols[i];
                //if (n <= colPI) {
                symColP[symRowP[n] + offset[n]] = colPI;
                symColP[symRowP[pCols[i]] + offset[colPI]] = n;
                pOutput[symRowP[n] + offset[n]] = pVals[i];
                pOutput[symRowP[colPI] + offset[colPI]] = pVals[i];
                //}
            }
            // Update offsets
            if (!present || (present && n <= colPI)) {
                atomicAdd(&offset[n], 1);
                if (colPI != n)
                    atomicAdd(&offset[colPI], 1);
            }
        }
    }
}

template <typename T>
static void barnes_symmetrize_(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) {
    int const* pRows = reinterpret_cast<int const*>(rowP->getSpecialBuffer());
    int* symRowP = reinterpret_cast<int*>(outputRows->specialBuffer());
    int* pRowCounts = reinterpret_cast<int*>(rowCounts->specialBuffer());
    auto stream = outputCols->getContext()->getCudaStream();

    fillUpsymRow<<<1, N, 128, *stream>>>(pRowCounts, symRowP, N);
    outputRows->syncToHost();
    // outputRows->printBuffer("SymRows are");
    int* symColP = reinterpret_cast<int*>(outputCols->specialBuffer());
    int const* pCols = reinterpret_cast<int const*>(colP->getSpecialBuffer());
    T const* pVals = reinterpret_cast<T const*>(valP->getSpecialBuffer());
    T* pOutput = reinterpret_cast<T*>(outputVals->specialBuffer());
    //std::vector<int> rowCountsV = rowCounts->getBufferAsVector<int>();
    auto offsetArr = NDArrayFactory::create<int>('c', {N});
    int* offset = reinterpret_cast<int*>(offsetArr.specialBuffer());
    symmetrizeKernel<T><<<1, 1, 1024, *stream>>>(pRows, pCols, pVals, symRowP, symColP, offset, pOutput, N);
    //PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(schedule(guided) shared(offset))
}

void barnes_symmetrize(const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts) {
    BUILD_SINGLE_SELECTOR(valP->dataType(), barnes_symmetrize_, (rowP, colP, valP, N, outputRows, outputCols, outputVals, rowCounts), NUMERIC_TYPES);

    *outputVals /= 2.0;
}
BUILD_SINGLE_TEMPLATE(template void barnes_symmetrize_, (const NDArray* rowP, const NDArray* colP, const NDArray* valP, Nd4jLong N, NDArray* outputRows, NDArray* outputCols, NDArray* outputVals, NDArray* rowCounts), NUMERIC_TYPES);
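`fillUpsymRow` builds the CSR row pointers of the symmetrized matrix as an exclusive prefix sum of the per-row counts: `symRowP[n]` is the total number of entries in rows `0..n-1`, with `symRowP[0] == 0`. A Python sketch of that prefix sum (illustrative; the kernel above recomputes each sum with atomics rather than scanning once):

```python
# Exclusive prefix sum over per-row counts -> CSR row-pointer array,
# the same result fillUpsymRow writes into symRowP.
def sym_row_pointers(row_counts):
    ptrs = [0] * (len(row_counts) + 1)
    for n in range(1, len(row_counts) + 1):
        ptrs[n] = ptrs[n - 1] + row_counts[n - 1]
    return ptrs
```

Counts `[2, 1, 3]` give pointers `[0, 2, 3, 6]`, so row 1's entries occupy positions `[2, 3)` of the column/value arrays.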
template <typename T>
static __global__ void edgeForcesKernel(int const* pRows, int const* pCols, T const* dataP, T const* vals, T* outputP, int N, int colCount, int rowSize) {
    // std::vector<T> buffer(colCount);

    auto start = blockIdx.x * blockDim.x + threadIdx.x;
    auto step = blockDim.x * gridDim.x;

    for (int n = start; n < N; n += step) {
        int begin = pRows[n];
        int end = pRows[n + 1];
        int shift = n * colCount;
        for (int i = begin; i < end; i++) {
            T const* thisSlice = dataP + pCols[i] * colCount;
            T res = 1;

            for (int k = 0; k < colCount; k++) {
                auto valTemp = dataP[shift + k] - thisSlice[k];
                res += valTemp * valTemp; // (dataP[shift + k] * dataP[shift + k] - 2 * dataP[shift + k] * thisSlice[k] + thisSlice[k] * thisSlice[k])
            }
            res = vals[i] / res;
            for (int k = 0; k < colCount; k++)
                math::atomics::nd4j_atomicAdd(&outputP[shift + k], T((dataP[shift + k] - thisSlice[k]) * res));
        }
        //atomicAdd(&shift, colCount);
    }
}

template <typename T>
static void barnes_edge_forces_(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output) {
    NDArray::prepareSpecialUse({output}, {data, rowP, colP, valP});
    T const* dataP = reinterpret_cast<T const*>(data->getSpecialBuffer());
    T const* vals = reinterpret_cast<T const*>(valP->getSpecialBuffer());
    T* outputP = reinterpret_cast<T*>(output->specialBuffer());
    int const* pRows = reinterpret_cast<int const*>(rowP->getSpecialBuffer());
    int const* pCols = reinterpret_cast<int const*>(colP->getSpecialBuffer());
    int colCount = data->columns();
    //auto shift = 0;
    auto rowSize = sizeof(T) * colCount;
    auto stream = output->getContext()->getCudaStream();
    edgeForcesKernel<T><<<1, 128, 1024, *stream>>>(pRows, pCols, dataP, vals, outputP, N, colCount, rowSize);
    NDArray::registerSpecialUse({output}, {rowP, colP, valP, data});
}

void barnes_edge_forces(const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray* output, NDArray const& data) {
    // Loop over all edges in the graph
    BUILD_SINGLE_SELECTOR(output->dataType(), barnes_edge_forces_, (rowP, colP, valP, N, &data, output), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void barnes_edge_forces_, (const NDArray* rowP, NDArray const* colP, NDArray const* valP, int N, NDArray const* data, NDArray* output), FLOAT_TYPES);
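Per edge `(n, j)` with weight `val`, the kernel above accumulates `val * (d_n - d_j) / (1 + ||d_n - d_j||^2)` into row `n` of the output, the Barnes-Hut t-SNE attractive-force term. A pure-Python sketch over CSR inputs (illustrative, not the CUDA kernel):

```python
# Edge-force accumulation over a CSR adjacency: rows are row pointers,
# cols/vals the per-edge targets and weights, data the point coordinates.
def edge_forces(rows, cols, vals, data):
    n_points, dim = len(data), len(data[0])
    out = [[0.0] * dim for _ in range(n_points)]
    for n in range(n_points):
        for i in range(rows[n], rows[n + 1]):
            other = data[cols[i]]
            diff = [a - b for a, b in zip(data[n], other)]
            res = 1.0 + sum(d * d for d in diff)   # 1 + squared distance
            for k in range(dim):
                out[n][k] += vals[i] * diff[k] / res
    return out
```

One edge from point 0 at 0.0 to point 1 at 1.0 with weight 2 gives `diff = -1`, denominator `2`, so row 0 accumulates `-1.0` while row 1, having no edges, stays zero.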
template <typename T>
|
template <typename T>
|
||||||
void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) {
|
void barnes_gains_(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) {
|
||||||
|
auto gainsInternal = LAMBDA_TTT(x, grad, eps) {
|
||||||
|
// return T((x + 2.) * nd4j::math::nd4j_sign<T,T>(grad) != nd4j::math::nd4j_sign<T,T>(eps)) + T(x * 0.8 * nd4j::math::nd4j_sign<T,T>(grad) != nd4j::math::nd4j_sign<T,T>(eps));
|
||||||
|
//return T((x + 2.) * nd4j::math::nd4j_sign<T,T>(grad) == nd4j::math::nd4j_sign<T,T>(eps)) + T(x * 0.8 * nd4j::math::nd4j_sign<T,T>(grad) == nd4j::math::nd4j_sign<T,T>(eps));
|
||||||
|
T res = nd4j::math::nd4j_sign<T,T>(grad) != nd4j::math::nd4j_sign<T,T>(eps) ? x + T(.2) : x * T(.8);
|
||||||
|
if(res < .01) res = .01;
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
input->applyTriplewiseLambda(gradX, epsilon, gainsInternal, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) {
|
void barnes_gains(NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), barnes_gains_, (input, gradX, epsilon, output), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template void barnes_gains_, (NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output), NUMERIC_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void barnes_gains_, (NDArray* input, NDArray* gradX, NDArray* epsilon, NDArray* output), NUMERIC_TYPES);

    bool cell_contains(NDArray* corner, NDArray* width, NDArray* point, Nd4jLong dimension) {
        auto cornerMinusWidth = *corner - *width;
        auto cornerPlusWidth = *corner + *width;
        cornerMinusWidth.syncToHost();
        cornerPlusWidth.syncToHost();
        for (Nd4jLong i = 0; i < dimension; i++) {
            if (cornerMinusWidth.e<double>(i) > point->e<double>(i))
                return false;

@@ -363,7 +363,7 @@ __global__ void logSoftMaxForVectorCuda(const void *vx, const Nd4jLong *xzShape
    temp = 0;

    // ************ evaluate value of exp(x[offset] - max) per each element, store it to shared memory shmem ************ //
-   // at the same evaluate sum of exponents, sum will be stored in shmem[0]
+   // at the same time evaluate sum of exponents, sum will be stored in shmem[0]
    for (int i = 0; i < numOfIters; ++i) {

        const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;

@@ -42,9 +42,10 @@ namespace helpers {
    xLen = shape::length(inputXshape);
    yLen = shape::length(inputYshape);
    outputLen = shape::length(outputShape);
-   speedWay = speedWay && shape::elementWiseStride(inputXshape) == 1;
-   speedWay = speedWay && shape::elementWiseStride(inputYshape) == 1;
-   speedWay = speedWay && shape::elementWiseStride(outputShape) == 1;
+   speedWay = true;
+   speedWay = speedWay && (shape::elementWiseStride(inputXshape) == 1);
+   speedWay = speedWay && (shape::elementWiseStride(inputYshape) == 1);
+   speedWay = speedWay && (shape::elementWiseStride(outputShape) == 1);
}
__syncthreads();

@@ -71,27 +72,38 @@ namespace helpers {
            }
        }
    }

    template <typename T>
    static void bdsLoopH(cudaStream_t* stream, void const* inputX, Nd4jLong const* inputXshape, void const* inputY, Nd4jLong const* inputYshape, void* output, Nd4jLong* outputShape) {
-       bdsLoopKernel<T><<<128, 256, 512, *stream>>>(inputX, inputXshape, inputY, inputYshape, output, outputShape);
+       bdsLoopKernel<T><<<1, 256, 512, *stream>>>(inputX, inputXshape, inputY, inputYshape, output, outputShape);
    }

    Nd4jStatus bdsFunctor(nd4j::LaunchContext * context, NDArray* x_shape, NDArray* y_shape, NDArray* output) {
        //int e = 0, x = 0, y = 0;
        NDArray::prepareSpecialUse({output}, {x_shape, y_shape});
        if (x_shape->lengthOf() == 1 || y_shape->lengthOf() == 1) { // exceptional case
            x_shape->syncToHost(); y_shape->syncToHost();
            if (x_shape->lengthOf() == y_shape->lengthOf()) {
                auto greater = (x_shape->e<Nd4jLong>(0) < y_shape->e<Nd4jLong>(0) ? y_shape : x_shape);
                output->assign(greater);
            }
            else {
                auto lesser = (x_shape->lengthOf() == 1 ? x_shape : y_shape);
                auto greater = (x_shape->lengthOf() == 1 ? y_shape : x_shape);
                output->assign(greater);
-               output->syncToHost();
-               output->p(output->lengthOf() - 1, *lesser);
+               auto lastG = greater->lengthOf() - 1;
+               auto lastL = lesser->lengthOf() - 1;
+               if (greater->e<Nd4jLong>(lastG) < lesser->e<Nd4jLong>(lastL))
+                   output->p(lastG, lesser->e(lastL));
                output->syncToDevice();
            }
        }
        else {
            //bdsLoopH(context->getCudaStream(), x->getSpecialBuffer(), x->getSpecialShapeInfo(), y->getSpecialBuffer(), y->getSpecialShape(), output->specialBuffer(), output->specialShapeInfo())
            BUILD_SINGLE_SELECTOR(output->dataType(), bdsLoopH, (context->getCudaStream(), x_shape->getSpecialBuffer(), x_shape->getSpecialShapeInfo(), y_shape->getSpecialBuffer(), y_shape->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo()), NUMERIC_TYPES);
        }
        NDArray::registerSpecialUse({output}, {x_shape, y_shape});
        return Status::OK();
    }

@@ -189,7 +189,7 @@ static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBloc
    // col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
    col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
}
-BUILD_SINGLE_TEMPLATE(template void col2imCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, const void *col, const Nd4jLong *colShapeInfo, void *im, const Nd4jLong *imShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), FLOAT_TYPES);
+BUILD_SINGLE_TEMPLATE(template void col2imCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t* stream, const void *col, const Nd4jLong *colShapeInfo, void *im, const Nd4jLong *imShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW), LIBND4J_TYPES);

//////////////////////////////////////////////////////////////////////////
void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const int sH, const int sW, const int pH, const int pW, const int iH, const int iW, const int dH, const int dW) {

@@ -201,7 +201,7 @@ void col2im(nd4j::LaunchContext& context, const NDArray& col, NDArray& im, const
    const int sharedMem = col.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;

    NDArray::prepareSpecialUse({&im}, {&col});
-   BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR(im.dataType(), col2imCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context.getCudaStream(), col.getSpecialBuffer(), col.getSpecialShapeInfo(), im.specialBuffer(), im.specialShapeInfo(), sH, sW, pH, pW, dH, dW), LIBND4J_TYPES);
    NDArray::registerSpecialUse({&im}, {&col});

    manager.synchronize();

libnd4j/include/ops/declarable/helpers/cuda/gradient.cu (new file, 42 lines)
@@ -0,0 +1,42 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author sgazeos@gmail.com
//

#include <ops/declarable/helpers/axis.h>
#include <op_boilerplate.h>

namespace nd4j {
namespace ops {
namespace helpers {
    template <typename T>
    static void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) {
        auto lambda = LAMBDA_TT(_x, _y, weight) {
            return _x - (_y * weight);
        };

        input->applyPairwiseLambda(step, lambda, output);
    }

    void applyGradientDescent(nd4j::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), applyGradientDescent_, (context, input, step, weight, output), FLOAT_TYPES);
    }
    BUILD_SINGLE_TEMPLATE(template void applyGradientDescent_, (LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output), FLOAT_TYPES);
}
}
}

@@ -27,13 +27,13 @@ namespace nd4j {
    template <typename T>
    static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong length) {

-       for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) {
+       for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; b += gridDim.x * blockDim.x) {
            auto blockBuffer = buffer + b * numBlocks;

-           Nd4jLong r = 1;
+           Nd4jLong r = 1LL;
-           for (int e = threadIdx.x; e < blockSize && e + (b * numBlocks) < length; e += blockDim.x) {
+           for (int e = 0; e < blockSize && e + (b * numBlocks) < length; e++) {
                auto v = longBytes<T>(blockBuffer[e]);
-               r = 31 * r + v;
+               r = 31LL * r + v;
            }

            tempBuffer[b] = r;

@@ -43,16 +43,17 @@ namespace nd4j {
    template <typename T>
    static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong lastLength) {

-       for (int b = blockIdx.x; b < numBlocks; b += gridDim.x) {
+       for (int b = blockIdx.x * blockDim.x + threadIdx.x; b < numBlocks; b += gridDim.x * blockDim.x) {
            auto blockBuffer = tempBuffer + b * numBlocks;
+           Nd4jLong r = 1LL;

-           Nd4jLong r = 1;
-           for (int e = threadIdx.x; e < blockSize && e + (b * numBlocks) < lastLength; e += blockDim.x) {
+           for (Nd4jLong e = 0; e < blockSize && e + (b * numBlocks) < lastLength; e++) {
                auto v = longBytes<T>(blockBuffer[e]);
-               r = 31 * r + v;
+               r = 31LL * r + v;
            }

            tempResult[b] = r;

        }
    }

@@ -89,7 +90,7 @@ namespace nd4j {
    auto tempResult = tempBufferB;

    // we divide array into 32 element chunks, and store intermediate results once
-   splitBufferToChuncks<T><<<numBlocks, length, 1024, *stream>>>(buffer, tempBuffer, numBlocks, blockSize, length);
+   splitBufferToChuncks<T><<<numBlocks, 1, 1024, *stream>>>(buffer, tempBuffer, numBlocks, blockSize, length);

    // we replace the pointer with the intermediate one, and repeat until only one chunk is left
    int iterationCount = 0;

@@ -98,7 +99,7 @@ namespace nd4j {
    numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1);

-   internalHash<Nd4jLong><<<numBlocks, lastLength, 1024, *stream>>>(tempBuffer, tempResult, numBlocks, blockSize, lastLength);
+   internalHash<Nd4jLong><<<numBlocks, 1, 1024, *stream>>>(tempBuffer, tempResult, numBlocks, blockSize, lastLength);

    iterationCount++;

@@ -112,10 +113,10 @@ namespace nd4j {
        }
    }

-   //lastStep<Nd4jLong><<<1,1,128, *stream>>>(result.specialBuffer(), tempBufferA, tempResult, length, blockSize);
+   lastStep<<<1,1,128, *stream>>>(reinterpret_cast<Nd4jLong*>(result.specialBuffer()), tempBufferA, tempResult, length, blockSize);
-   tempA.syncToHost();
+   // tempA.syncToHost();
-   tempB.syncToHost();
+   // tempB.syncToHost();
-   result.assign((length <= blockSize ? tempA.e(0) : tempB.e(0)));
+   // result.assign((length <= blockSize ? tempA.e(0) : tempB.e(0)));

    NDArray::registerSpecialUse({&result}, {&array});
}

@@ -85,7 +85,7 @@ template <typename T>
static void im2colCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext & context, const void *image, void *columns, const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, int sH, int sW, int pH, int pW, int dH, int dW, double zeroPadVal) {
    im2colCuda<T><<<blocksPerGrid, threadsPerBlock, threadsPerBlock * sizeof(Nd4jLong) * 6 /* rank of columns = 6 */, *context.getCudaStream()>>>(image, columns, imShapeInfo, colShapeInfo, sH, sW, pH, pW, dH, dW, zeroPadVal);
}
-BUILD_SINGLE_TEMPLATE(template void im2colCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext& context, const void *image, void *columns, const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const double zeroPadVal), FLOAT_TYPES);
+BUILD_SINGLE_TEMPLATE(template void im2colCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, nd4j::LaunchContext& context, const void *image, void *columns, const Nd4jLong *imShapeInfo, const Nd4jLong *colShapeInfo, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const double zeroPadVal), LIBND4J_TYPES);

//////////////////////////////////////////////////////////////////////////
void im2col(nd4j::LaunchContext& context, const NDArray& image, NDArray& columns, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const NDArray& arrZeroPadVal) {

@@ -96,7 +96,7 @@ void im2col(nd4j::LaunchContext& context, const NDArray& image, NDArray& columns
    const int blocksPerGrid = (columns.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

    NDArray::prepareSpecialUse({&columns}, {&image});
-   BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, image.getSpecialBuffer(), columns.getSpecialBuffer(), image.getSpecialShapeInfo(), columns.getSpecialShapeInfo(), sH, sW, pH, pW, dH, dW, arrZeroPadVal.e<double>(0)), FLOAT_TYPES);
+   BUILD_SINGLE_SELECTOR(columns.dataType(), im2colCudaLauncher, (blocksPerGrid, threadsPerBlock, context, image.getSpecialBuffer(), columns.getSpecialBuffer(), image.getSpecialShapeInfo(), columns.getSpecialShapeInfo(), sH, sW, pH, pW, dH, dW, arrZeroPadVal.e<double>(0)), LIBND4J_TYPES);
    NDArray::registerSpecialUse({&columns}, {&image});

    manager.synchronize();

@@ -71,7 +71,7 @@ namespace helpers {
    template <typename T>
    linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
        auto functor = LAMBDA_TT(x, y){
-           return x >= (T)0.f ? T(1.f) : T(0.f);
+           return x >= (T)0.f ? y : T(0.f);
        };

        input->applyPairwiseLambda(epsilon, functor, output);

@@ -30,6 +30,7 @@
#include <ops/declarable/helpers/lstmBlock.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>
+#include <helpers/PointersManager.h>
#include <array/NDArrayList.h>
#include <iterator>

@@ -43,40 +44,40 @@ namespace helpers {
void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b,
              NDArray* ht, NDArray* ct, const std::vector<double>& params) {

-   // xt   input [bS x inSize]
+   // xt   input [bS x nIn]
-   // ht_1 previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!!
+   // ht_1 previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=nOut!!!
-   // ct_1 previous cell state  [bS x numUnits], that is at previous time step t-1
+   // ct_1 previous cell state  [bS x nOut], that is at previous time step t-1

-   // Wx   input-to-hidden weights, [inSize x 4*numUnits]
+   // Wx   input-to-hidden weights, [nIn x 4*nOut]
-   // Wh   hidden-to-hidden weights, [numProj x 4*numUnits]
+   // Wh   hidden-to-hidden weights, [numProj x 4*nOut]
-   // Wc   diagonal weights for peephole connections [3*numUnits]
+   // Wc   diagonal weights for peephole connections [3*nOut]
-   // Wp   projection weights [numUnits x numProj]
+   // Wp   projection weights [nOut x numProj]
-   // b    biases, [4*numUnits]
+   // b    biases, [4*nOut]

    // ht   current cell output [bS x numProj], that is at current time step t
-   // ct   current cell state  [bS x numUnits], that is at current time step t
+   // ct   current cell state  [bS x nOut], that is at current time step t

    const bool peephole   = (bool)params[0];   // if true, provide peephole connections
-   const bool projection = (bool)params[1];   // if true, then projection is performed, if false then numProj==numUnits is mandatory!!!!
+   const bool projection = (bool)params[1];   // if true, then projection is performed, if false then numProj==nOut is mandatory!!!!
    double clippingCellValue = params[2];      // clipping value for ct, if it is not equal to zero, then cell state is clipped
    double clippingProjValue = params[3];      // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped
    const double forgetBias = params[4];

    const int bS = xt->sizeAt(0);
-   const int inSize = xt->sizeAt(1);
+   const int nIn = xt->sizeAt(1);
    const int numProj = ht_1->sizeAt(1);
-   const int numUnits = ct_1->sizeAt(1);
+   const int nOut = ct_1->sizeAt(1);

-   auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + *b;   // [bS x 4*numUnits] + [bS x 4*numUnits] + [1 x 4*numUnits] = [bS x 4*numUnits]
+   auto z = mmul(*xt, *Wx) + mmul(*ht_1, *Wh) + *b;   // [bS x 4*nOut] + [bS x 4*nOut] + [1 x 4*nOut] = [bS x 4*nOut]

-   auto zit = z({0,0, 0, numUnits});            // z for input gate,  = mmul(Wxi,xt) + mmul(Whi,ht_1) + bi = [bS x numUnits]
+   auto zit = z({0,0, 0,nOut});                 // z for input gate,  = mmul(Wxi,xt) + mmul(Whi,ht_1) + bi = [bS x nOut]
-   auto zft = z({0,0, numUnits, 2*numUnits});   // z for forget gate, = mmul(Wxf,xt) + mmul(Whf,ht_1) + bf = [bS x numUnits]
+   auto zft = z({0,0, nOut,2*nOut});            // z for forget gate, = mmul(Wxf,xt) + mmul(Whf,ht_1) + bf = [bS x nOut]
-   auto zct = z({0,0, 2*numUnits, 3*numUnits}); // z for cell state,  = mmul(Wxc,xt) + mmul(Whc,ht_1) + bc = [bS x numUnits]
+   auto zct = z({0,0, 2*nOut,3*nOut});          // z for cell state,  = mmul(Wxc,xt) + mmul(Whc,ht_1) + bc = [bS x nOut]
-   auto zot = z({0,0, 3*numUnits, 4*numUnits}); // z for output gate, = mmul(Wxo,xt) + mmul(Who,ht_1) + bo = [bS x numUnits]
+   auto zot = z({0,0, 3*nOut,4*nOut});          // z for output gate, = mmul(Wxo,xt) + mmul(Who,ht_1) + bo = [bS x nOut]

    if(peephole) {   // add peephole connections: z + ct_1*Wc
-       zit += (*ct_1) * (*Wc)({0, numUnits});          // add peephole connections to input gate
+       zit += (*ct_1) * (*Wc)({0, nOut});              // add peephole connections to input gate
-       zft += (*ct_1) * (*Wc)({numUnits, 2*numUnits}); // add peephole connections to forget gate
+       zft += (*ct_1) * (*Wc)({nOut, 2*nOut});         // add peephole connections to forget gate
    }

    // current cell state = ft*ct_1 + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc)

@@ -84,20 +85,20 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h

    // if clipping value is provided then cell state is clipped by this value prior to the cell output activation
    if(clippingCellValue > 0.0)
-       clipping(ct, clippingCellValue);
+       ct->applyScalar(scalar::LstmClip, clippingCellValue);

    if(peephole)
-       zot += (*ct) * (*Wc)({{2*numUnits, 3*numUnits}});   // add peephole connections to output gate zot + ct*Wc
+       zot += (*ct) * (*Wc)({{2*nOut, 3*nOut}});           // add peephole connections to output gate zot + ct*Wc

    // current cell output = ot*tanh(ct)
-   auto htNoPeepHole = sigmoid(zot) * tanh(*ct);   // = [bS x numUnits]
+   auto htNoPeepHole = sigmoid(zot) * tanh(*ct);   // = [bS x nOut]

    // apply projection
    if(projection) {
-       ht->assign( mmul(htNoPeepHole, *Wp) );   // [bS x numUnits] * [numUnits x numProj] = [bS x numProj]
+       ht->assign( mmul(htNoPeepHole, *Wp) );   // [bS x nOut] * [nOut x numProj] = [bS x numProj]
        // if clipping projection is provided then projected cell output state is clipped by this value
        if(clippingProjValue != 0.)
-           clipping(ht, clippingProjValue);
+           ht->applyScalar(scalar::LstmClip, clippingProjValue);
    }
    else
        ht->assign(&htNoPeepHole);

@@ -109,14 +110,14 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h
                   const NDArray* W, const NDArray* Wci, const NDArray* Wcf, const NDArray* Wco, const NDArray* b,
                   NDArray* i, NDArray* c, NDArray* f, NDArray* o, NDArray* z, NDArray* h, NDArray* y, const std::vector<double>& params) {
    /* Input arrays:
-    * 0: xt              - input [bS, inSize] at time t
+    * 0: xt              - input [bS, nIn] at time t
-    * 1: cLast (cs_prev) - previous cell state [bS, numUnits], time t-1
+    * 1: cLast (cs_prev) - previous cell state [bS, nOut], time t-1
-    * 2: yLast (h_prev)  - previous output [bS, numUnits], time t-1
+    * 2: yLast (h_prev)  - previous output [bS, nOut], time t-1
-    * 3: W               - weights, concatenated (input-to-hidden, hidden-to-hidden), [(inSize+numUnits), 4*numUnits]
+    * 3: W               - weights, concatenated (input-to-hidden, hidden-to-hidden), [(nIn+nOut), 4*nOut]
-    * 4: Wci             - weights - cell peephole (t-1) connections to input modulation gate, [numUnits]
+    * 4: Wci             - weights - cell peephole (t-1) connections to input modulation gate, [nOut]
-    * 5: Wcf             - weights - cell peephole (t-1) connections to forget gate, [numUnits]
+    * 5: Wcf             - weights - cell peephole (t-1) connections to forget gate, [nOut]
-    * 6: Wco             - weights - cell peephole (t) connections to output gate, [numUnits]
+    * 6: Wco             - weights - cell peephole (t) connections to output gate, [nOut]
-    * 7: b               - biases, [4*numUnits]
+    * 7: b               - biases, [4*nOut]
     *
     * Input integer arguments:
     * 0: if not zero, provide peephole connections

@@ -126,42 +127,34 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h
     * 1: clipping value for cell state, if it is not equal to zero, then cell state is clipped
     *
     * Output arrays:
-    * 0: i      - Input modulation gate activations [bS, numUnits]
+    * 0: i      - Input modulation gate activations [bS, nOut]
-    * 1: c (cs) - Cell state (pre tanh) [bs, numUnits]
+    * 1: c (cs) - Cell state (pre tanh) [bs, nOut]
-    * 2: f      - Output - forget gate activations [bs, numUnits]
+    * 2: f      - Output - forget gate activations [bs, nOut]
-    * 3: o      - Output - output gate activations [bs, numUnits]
+    * 3: o      - Output - output gate activations [bs, nOut]
-    * 4: z (ci) - Output - block input [bs, numUnits]
+    * 4: z (ci) - Output - block input [bs, nOut]
-    * 5: h (co) - Cell state, post tanh [bs, numUnits]
+    * 5: h (co) - Cell state, post tanh [bs, nOut]
-    * 6: y (h)  - Current cell output [bS, numUnits], time t
+    * 6: y (h)  - Current cell output [bS, nOut], time t
     */
    const bool peephole = (bool)params[0];        // if true, provide peephole connections
    const double forgetBias = params[1];
    const double clippingCellValue = params[2];   // clipping value for ct, if it is not equal to zero, then cell state is clipped

    const int bS = xt->sizeAt(0);
-   const int inSize = xt->sizeAt(1);
+   const int nIn = xt->sizeAt(1);
-   const int numUnits = cLast->sizeAt(1);
+   const int nOut = cLast->sizeAt(1);

    // Concat inputs: [xt, yt-1]: concat([bs,nIn],[bs,nOut]) -> [bs, (nIn+nOut)]
-   nd4j::ops::concat concat;
-   Context cContext(119);
-   auto concatOut = NDArrayFactory::create(xt->ordering(), {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, xt->dataType(), xt->getContext());
-   cContext.setInputArray(0, const_cast<NDArray*>(xt), false);
-   cContext.setInputArray(1, const_cast<NDArray*>(yLast), false);
-   cContext.setOutputArray(0, &concatOut, false);
-   cContext.getIArguments()->emplace_back(1);
-   concat.execute(&cContext);
-   auto m = mmul(concatOut, *W);   // mmul: [bs, (nIn+numUnits)] * [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits]
-   m += (*b);
+   NDArray concatOut(xt->ordering(), {xt->sizeAt(0), xt->sizeAt(1) + yLast->sizeAt(1)}, xt->dataType(), xt->getContext());
+   helpers::concat(xt->getContext(), {const_cast<NDArray*>(xt), const_cast<NDArray*>(yLast)}, concatOut, {1});
+   auto m = mmul(concatOut, *W);   // mmul: [bs, (nIn+nOut)] * [(nIn+nOut), 4*nOut] = [bs, 4*nOut]
+   m += (*b);                      // addiRowVector

    // Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o])
-   auto zi = m({0,0, 0, numUnits});          // z for input modulation gate, [bS, numUnits]
+   auto zi = m({0,0, 0, nOut});              // z for input modulation gate, [bS, nOut]
-   auto zz = m({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits]
+   auto zz = m({0,0, nOut, 2*nOut});         // z for block input, [bS, nOut]
|
||||||
auto zf = m({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits]
|
auto zf = m({0,0, 2*nOut, 3*nOut}); // z for forget gate, [bS, nOut]
|
||||||
auto zo = m({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits]
|
auto zo = m({0,0, 3*nOut, 4*nOut}); // z for output gate, [bS, nOut]
|
||||||
|
|
||||||
if(peephole) { // add peephole connections: z + ct_1*Wc
|
if(peephole) { // add peephole connections: z + ct_1*Wc
|
||||||
zi += (*cLast) * (*Wci); // add peephole connections to input gate
|
zi += (*cLast) * (*Wci); // add peephole connections to input gate
|
||||||
@ -169,26 +162,22 @@ void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* h
|
|||||||
}
|
}
|
||||||
|
|
||||||
// current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc
|
// current sell state = ft*cLast + it*tanh(mmul(Wxc,xt) + mmul(Whc,ht_1) + bc
|
||||||
if(forgetBias != 0.0){
|
if(forgetBias != 0.0)
|
||||||
zf += forgetBias;
|
zf += forgetBias;
|
||||||
}
|
|
||||||
|
|
||||||
zz.applyTransform(transform::Tanh, z); //z = tanh(zz)
|
zz.applyTransform(transform::Tanh, z); //z = tanh(zz)
|
||||||
zi.applyTransform(transform::Sigmoid, i); //i = sigmoid(zi)
|
zi.applyTransform(transform::Sigmoid, i); //i = sigmoid(zi)
|
||||||
zf.applyTransform(transform::Sigmoid, f); //f = sigmoid(zf);
|
zf.applyTransform(transform::Sigmoid, f); //f = sigmoid(zf);
|
||||||
|
|
||||||
|
|
||||||
//cell state = blockInput .* inputGate + prevCellState .* forgetGate
|
//cell state = blockInput .* inputGate + prevCellState .* forgetGate
|
||||||
z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * i
|
z->applyPairwiseTransform(pairwise::Multiply, i, c, nullptr); //c = z * i
|
||||||
auto temp = (*f) * (*cLast);
|
auto temp = (*f) * (*cLast);
|
||||||
*c += temp; //c = (i * z) + (zf * (*cLast))
|
*c += temp; //c = (i * z) + (zf * (*cLast))
|
||||||
c->applyTransform(transform::Tanh, h); //h = tanh(c)
|
c->applyTransform(transform::Tanh, h); //h = tanh(c)
|
||||||
|
|
||||||
|
|
||||||
// if clipping value is provided then cell state is clipped by this value prior to the cell output activation
|
// if clipping value is provided then cell state is clipped by this value prior to the cell output activation
|
||||||
if(clippingCellValue > 0.0) {
|
if(clippingCellValue > 0.0)
|
||||||
clipping(c, clippingCellValue);
|
c->applyScalar(scalar::LstmClip, clippingCellValue);
|
||||||
}
|
|
||||||
|
|
||||||
if(peephole) {
|
if(peephole) {
|
||||||
// add peephole connections to output gate zot + ct*Wc
|
// add peephole connections to output gate zot + ct*Wc
|
||||||
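The concat-then-single-mmul pattern the hunk above converges on (one `[bS, nIn+nOut] x [nIn+nOut, 4*nOut]` product, then four column slices) can be sketched in plain host-side C++. This is an illustration of the gate math only, not the libnd4j NDArray API; it skips peepholes and clipping, and `lstmStep`/its parameter names are hypothetical helpers mirroring the diff's `nIn`/`nOut`/`[i, z, f, o]` column ordering.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

struct LstmOut { std::vector<double> c, h; }; // cell state and cell output

static double sigmoidFn(double v) { return 1.0 / (1.0 + std::exp(-v)); }

// One LSTM step: m = [x, hPrev] * W + b, then slice m into the four gates.
LstmOut lstmStep(const std::vector<double>& x,      // [nIn]
                 const std::vector<double>& hPrev,  // [nOut]
                 const std::vector<double>& cPrev,  // [nOut]
                 const std::vector<double>& W,      // [(nIn+nOut) x 4*nOut], row-major
                 const std::vector<double>& b,      // [4*nOut]
                 int nIn, int nOut, double forgetBias) {
    const int cols = 4 * nOut;
    std::vector<double> m(b); // start from the bias row vector
    for (int k = 0; k < nIn + nOut; ++k) {
        const double v = k < nIn ? x[k] : hPrev[k - nIn]; // virtual concat [x, hPrev]
        for (int j = 0; j < cols; ++j)
            m[j] += v * W[k * cols + j];
    }
    LstmOut out{std::vector<double>(nOut), std::vector<double>(nOut)};
    for (int j = 0; j < nOut; ++j) {
        const double i = sigmoidFn(m[j]);                       // input gate    (columns [0, nOut))
        const double z = std::tanh(m[nOut + j]);                // block input   (columns [nOut, 2*nOut))
        const double f = sigmoidFn(m[2 * nOut + j] + forgetBias); // forget gate (columns [2*nOut, 3*nOut))
        const double o = sigmoidFn(m[3 * nOut + j]);            // output gate   (columns [3*nOut, 4*nOut))
        out.c[j] = z * i + cPrev[j] * f;       // c = (i * z) + (f * cPrev), as in the diff
        out.h[j] = o * std::tanh(out.c[j]);    // cell output
    }
    return out;
}
```

The point of the restructuring in the hunk is visible here: one matrix product produces all four pre-activations, and the gates differ only in which column block they read.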
@@ -805,6 +805,8 @@ namespace helpers {

     if(!inplace)
         output->assign(tempOutput.get());
+    else
+        input->assign(tempOutput.get());

     NDArray::registerSpecialUse({output}, {input});
     return Status::OK();
@@ -812,6 +814,7 @@ namespace helpers {

     // template <typename T>
     int cholesky_(LaunchContext* context, NDArray* input, NDArray* output, bool inplace) {
+        NDArray::prepareSpecialUse({output}, {input});
         if (input->dataType() == DataType::DOUBLE)
             cholesky__<double>(context, input, output, inplace);
         else if (input->dataType() == DataType::FLOAT32)
@@ -822,6 +825,7 @@ namespace helpers {
             cholesky__<float>(context, tempOutput.get(), tempOutput.get(), true);
             output->assign(tempOutput.get());
         }
+        NDArray::registerSpecialUse({output}, {input});
         return Status::OK();
     }

@@ -832,23 +836,23 @@ namespace helpers {
     // BUILD_SINGLE_TEMPLATE(template int cholesky_, (LaunchContext* context, NDArray* input, NDArray* output, bool inplace), FLOAT_TYPES);
     BUILD_SINGLE_TEMPLATE(template int inverse_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output), FLOAT_NATIVE);

-    __global__ void logDetKernel(void* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, void* outputBuf, Nd4jLong* outputShape) {
-        __shared__ double* output;
-        __shared__ double* input;
-        __shared__ int n2;
+    __global__ void logDetKernel(double* inputBuf, Nd4jLong* inputShape, Nd4jLong batchNum, Nd4jLong* tadShape, Nd4jLong* tadOffsets, double* outputBuf, Nd4jLong* outputShape) {
+        __shared__ int n;
         if (threadIdx.x == 0) {
-            output = reinterpret_cast<double*>(outputBuf);
-            input = reinterpret_cast<double*>(inputBuf);
-            n2 = shape::sizeAt(inputShape, -1) * shape::sizeAt(inputShape, -1);
+            n = shape::sizeAt(inputShape, -1); // * shape::sizeAt(inputShape, -1);
         }
         __syncthreads();

-        for (Nd4jLong i = blockIdx.x; i < batchNum; i += gridDim.x) {
+        double* output = outputBuf;
+        double* input = inputBuf;
+
+        for (auto i = blockIdx.x; i < batchNum; i += gridDim.x) {
             double* current = input + tadOffsets[i];
             Nd4jLong* shapeOf = shape::shapeOf(tadShape);
             Nd4jLong* strideOf = shape::stride(tadShape);
             auto zIndex = shape::getIndexOffset(i, outputShape, batchNum);
-            for (Nd4jLong e = threadIdx.x; e < n2; e += blockDim.x) {
+            for (auto e = threadIdx.x; e < n; e += blockDim.x) {
                 Nd4jLong diag[] = {e, e};
                 auto xIndex = shape::getOffset(0, shapeOf, strideOf, diag, 2);
                 math::atomics::nd4j_atomicAdd(&output[zIndex], math::nd4j_log<double,double>(current[xIndex] * current[xIndex]));
@@ -858,17 +862,27 @@ namespace helpers {

     int logdetFunctor(nd4j::LaunchContext* context, NDArray* input, NDArray* output) {
         NDArray::prepareSpecialUse({output}, {input});
-        auto tempOutput = input->dup('c');
         auto n2 = input->sizeAt(-1) * input->sizeAt(-2);
         auto stream = context->getCudaStream();
-        cholesky(context, tempOutput, tempOutput, true);
-        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {tempOutput->rankOf() - 2, tempOutput->rankOf() - 1});
-        //for (Nd4jLong e = 0; e < output->lengthOf(); e++) {
-        auto outputBuf = reinterpret_cast<double*>(output->specialBuffer()); // + e * n2;
-        logDetKernel<<<packX.numberOfTads(), n2, 128, *stream>>>(tempOutput->specialBuffer(), tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo());
+        std::unique_ptr<NDArray> tempOutput(input->dup());
+        // auto inputs = tempOutput->allTensorsAlongDimension({input->rankOf() - 2, input->rankOf() - 1});
+        // for (Nd4jLong e = 0; e < packX.numberOfTads(); e++) {
+        //     auto subArray = inputs->at(e);
+        //     cholesky(context, subArray, subArray, true);
+        // }
+        // delete inputs;
+        cholesky(context, input, tempOutput.get(), false);
+        tempOutput->syncToHost();
+        tempOutput->printIndexedBuffer("Cholesky res!!!");
+        auto outputBuf = reinterpret_cast<double*>(output->specialBuffer()); // + e * n2;
+        auto inputBuf = reinterpret_cast<double*>(tempOutput->specialBuffer());
+        output->assign(0);
+        output->syncToDevice();
+        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(tempOutput->getShapeInfo(), {input->rankOf() - 2, input->rankOf() - 1});
+        logDetKernel<<<packX.numberOfTads(), n2, 128, *stream>>>(inputBuf, tempOutput->specialShapeInfo(), packX.numberOfTads(), packX.specialShapeInfo(), packX.specialOffsets(), outputBuf, output->specialShapeInfo());
         // }
         NDArray::registerSpecialUse({output}, {input});
-        delete tempOutput;
+        //delete tempOutput;
         return Status::OK();
     }
 }
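The identity the `logDetKernel` hunk relies on is worth making explicit: after a Cholesky factorization A = L·Lᵀ, log(det(A)) = Σᵢ log(L[i][i]²), which is exactly the `log(current[xIndex] * current[xIndex])` term the kernel atomically accumulates over the diagonal. A minimal host-side sketch for a single dense row-major SPD matrix (the CUDA version does the same per TAD in the batch; `logdetViaCholesky` is an illustrative name, not a libnd4j function):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// log(det(A)) for an n x n SPD matrix via in-place Cholesky (lower triangle).
double logdetViaCholesky(std::vector<double> a, int n) {
    for (int j = 0; j < n; ++j) {
        for (int k = 0; k < j; ++k)
            a[j * n + j] -= a[j * n + k] * a[j * n + k];
        a[j * n + j] = std::sqrt(a[j * n + j]);
        for (int i = j + 1; i < n; ++i) {
            for (int k = 0; k < j; ++k)
                a[i * n + j] -= a[i * n + k] * a[j * n + k];
            a[i * n + j] /= a[j * n + j];
        }
    }
    double res = 0.0;
    for (int i = 0; i < n; ++i)
        res += std::log(a[i * n + i] * a[i * n + i]); // the kernel's log(c*c) term
    return res;
}
```

For A = [[4,2],[2,3]] this gives log(8), matching det(A) = 8 directly, which is why the diff can sum over only the n diagonal entries (`n`) rather than all n² elements (the old `n2` bound).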
|
@@ -16,15 +16,206 @@
 //
 // Created by raver119 on 19.01.18.
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <ops/declarable/helpers/s_t_b.h>
+#include <PointersManager.h>

 namespace nd4j {
 namespace ops {
 namespace helpers {

+///////////////////////////////////////////////////////////////////
+template<typename T>
+__global__ static void batchToSpaceCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, const uint cropLeft) {
+
+    // input  [bS, H * blockSize, W * blockSize, iC]
+    // output [bS, H * blockSize - cropBottom - cropTop, W * blockSize - cropLeft - cropRight, iC]
+
+    // if (cropTop = cropBottom = cropRight = cropLeft = 0) shapes are the same
+    // else:
+    // oH -> [cropBottom, iH - cropTop]
+    // oW -> [cropLeft, iW - cropRight]
+    // xLen > zLen
+
+    const auto x = reinterpret_cast<const T*>(vx);
+    auto z = reinterpret_cast<T*>(vz);
+
+    __shared__ int rank;
+    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;
+
+    if (threadIdx.x == 0) {
+        extern __shared__ unsigned char shmem[];
+        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
+
+        rank = shape::rank(zShapeInfo);
+        zLen = shape::length(zShapeInfo);
+        totalThreads = gridDim.x * blockDim.x;
+    }
+
+    __syncthreads();
+
+    auto coords = sharedMem + threadIdx.x * rank;
+
+    const auto i = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if(i >= zLen)
+        return;
+
+    shape::index2coords(rank, zShapeInfo + 1, i, zLen, coords);
+
+    const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
+
+    coords[1] += cropBottom;
+    coords[2] += cropLeft;
+
+    const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
+
+    z[zOffset] = x[xOffset];
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+static void batchToSpaceCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, const uint cropLeft) {
+    batchToSpaceCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, cropBottom, cropLeft);
+}
+BUILD_SINGLE_TEMPLATE(template void batchToSpaceCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint cropBottom, const uint cropLeft), LIBND4J_TYPES);
+
+///////////////////////////////////////////////////////////////////
+void batchToSpace(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight, const uint blockSize) {
+
+    // [bS*blockSize*blockSize, H/blockSize, W/blockSize, iC] is rearranged/permuted to [bS, oH, oW, iC]
+    // oH = H - cropTop - cropBottom
+    // oW = W - cropLeft - cropRight
+
+    NDArray inputRearranged0 = input.reshape(input.ordering(), {blockSize, blockSize, output.sizeAt(0), input.sizeAt(1), input.sizeAt(2), input.sizeAt(3)});
+    inputRearranged0.permutei({2, 3,0, 4,1, 5});
+
+    if(input.lengthOf() == output.lengthOf()) {
+        output.assign(inputRearranged0);
+    }
+    else {
+        NDArray inputRearranged1 = inputRearranged0.reshape(input.ordering(), {output.sizeAt(0), input.sizeAt(1) * blockSize, input.sizeAt(2) * blockSize, input.sizeAt(3)});
+
+        const int threadsPerBlock = MAX_NUM_THREADS / 2;
+        const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+        const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128;
+
+        PointersManager manager(context, "batchToSpace");
+
+        NDArray::prepareSpecialUse({&output}, {&inputRearranged1});
+        BUILD_SINGLE_SELECTOR(input.dataType(), batchToSpaceCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), inputRearranged1.getSpecialBuffer(), inputRearranged1.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), cropBottom, cropLeft), LIBND4J_TYPES);
+        NDArray::registerSpecialUse({&output}, {&inputRearranged1});
+
+        manager.synchronize();
+    }
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+__global__ static void spaceToBatchCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight) {
+
+    // input  [bS, H * blockSize - padBottom - padTop, W * blockSize - padLeft - padRight, iC]
+    // output [bs, H * blockSize, W * blockSize, iC]
+
+    // if (padTop = padBottom = padRight = padLeft = 0) shapes are the same
+    // else:
+    // iH -> [padBottom, oH - padTop]
+    // iW -> [padLeft, oW - padRight]
+    // zLen > xLen
+
+    const auto x = reinterpret_cast<const T*>(vx);
+    auto z = reinterpret_cast<T*>(vz);
+
+    __shared__ int rank;
+    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;
+
+    if (threadIdx.x == 0) {
+        extern __shared__ unsigned char shmem[];
+        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
+
+        rank = shape::rank(zShapeInfo);
+        zLen = shape::length(zShapeInfo);
+        totalThreads = gridDim.x * blockDim.x;
+    }
+
+    __syncthreads();
+
+    auto coords = sharedMem + threadIdx.x * rank;
+
+    const auto i = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if(i >= zLen)
+        return;
+
+    shape::index2coords(rank, zShapeInfo + 1, i, zLen, coords);
+
+    const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
+
+    if(coords[1] >= padBottom && coords[1] < zShapeInfo[2] - padTop && coords[2] >= padLeft && coords[2] < zShapeInfo[3] - padRight) {
+        coords[1] -= padBottom;
+        coords[2] -= padLeft;
+
+        const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
+
+        z[zOffset] = x[xOffset];
+    }
+    else
+        z[zOffset] = 0.f;
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+static void spaceToBatchCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight) {
+    spaceToBatchCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, padBottom, padTop, padLeft, padRight);
+}
+BUILD_SINGLE_TEMPLATE(template void spaceToBatchCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight), LIBND4J_TYPES);
+
+///////////////////////////////////////////////////////////////////
+void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight, const uint blockSize) {
+
+    // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]
+
+    NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)});
+    outputRearranged0.permutei({2, 3,0, 4,1, 5});
+
+    if(input.lengthOf() == output.lengthOf()) {
+        outputRearranged0.assign(input);
+    }
+    else {
+        NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)});
+
+        const int threadsPerBlock = MAX_NUM_THREADS / 2;
+        const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+        const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128;
+
+        PointersManager manager(context, "spaceToBatch");
+
+        NDArray::prepareSpecialUse({&outputRearranged1}, {&input});
+        BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), outputRearranged1.specialBuffer(), outputRearranged1.specialShapeInfo(), padBottom, padTop, padLeft, padRight), LIBND4J_TYPES);
+        NDArray::registerSpecialUse({&outputRearranged1}, {&input});
+
+        manager.synchronize();
+
+        if(output.getSpecialBuffer() != outputRearranged1.getSpecialBuffer())
+            outputRearranged0.assign(outputRearranged1);
+    }
+}
+
+/*
 template <int N, bool B2S>
 struct SpaceToBatchHelper {
 template <typename T>
@@ -64,11 +255,6 @@ namespace helpers {
     SpaceToBatchHelper<NUM_BLOCK_DIMS, B2S>::run(ptrSpace, space_shape, space_strides, block_shape, pad_start, block_offsets, ptrBatch, batch_shape, batch_strides);
 };

-Nd4jStatus _spaceToBatch(nd4j::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector<Nd4jLong> &internal_input_shape, std::vector<Nd4jLong> &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *paddings) {
-
-    return Status::OK();
-}

 Nd4jStatus _batchToSpace(nd4j::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector<Nd4jLong> &internal_input_shape, std::vector<Nd4jLong> &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *crops) {

     return Status::OK();
@@ -86,6 +272,8 @@ namespace helpers {

 #undef STB_BOOL
 #undef STB_DIM
+*/

 }
 }
 }
@@ -19,12 +19,109 @@
 //

 #include <ops/declarable/helpers/sg_cb.h>
+#include <cuda_exception.h>
+#include <NDArrayFactory.h>

 #define HS_MAX_EXP 6.0f

 namespace nd4j {
 namespace ops {
 namespace helpers {
+    template <typename T>
+    __global__ void hSoftmaxKernel(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) {
+
+        auto syn0 = reinterpret_cast<T*>(vsyn0);
+        auto syn1 = reinterpret_cast<T*>(vsyn1);
+        auto expTable = reinterpret_cast<T*>(vexpTable);
+        auto neu1e = reinterpret_cast<T*>(vneu1e);
+
+        T dot(0.0f);
+        T g(0.0f);
+        T f(0.0f);
+
+        // dot
+        for (int e = 0; e < vectorLength; e++) {
+            dot += syn0[e] * syn1[e];
+        }
+
+        // gradient
+        if (dot < (T) - HS_MAX_EXP || dot >= (T) HS_MAX_EXP)
+            return;
+
+        int idx = static_cast<int>((dot + HS_MAX_EXP) * ((float) expLength / HS_MAX_EXP / 2.0f));
+
+        if (idx >= expLength || idx < 0)
+            return;
+
+        f = expTable[idx];
+        g = (static_cast<T>(1.0f) - static_cast<T>(code) - f) * (T) alpha;
+
+        // axpy1
+        for (int e = 0; e < vectorLength; e++) {
+            neu1e[e] = g * syn1[e] + neu1e[e];
+        }
+
+        // axpy2
+        if (!isInference) {
+            for (int e = 0; e < vectorLength; e++) {
+                syn1[e] = g * syn0[e] + syn1[e];
+            }
+        }
+    }
+
+    template <typename T>
+    void hSoftmax_(void *vsyn0, void *vsyn1, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference, cudaStream_t* stream) {
+        hSoftmaxKernel<T><<<1,1,128, *stream>>>(vsyn0, vsyn1, vexpTable, vneu1e, alpha, vectorLength, code, expLength, isInference);
+    }
+
+    template <typename T>
+    __global__ void nSamplingKernel(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference) {
+        auto syn0 = reinterpret_cast<T*>(vsyn0);
+        auto syn1Neg = reinterpret_cast<T*>(vsyn1Neg);
+        auto expTable = reinterpret_cast<T*>(vexpTable);
+        auto neu1e = reinterpret_cast<T*>(vneu1e);
+
+        T dot = (T) 0.0f;
+        T g = (T) 0.0f;
+
+        for (int e = 0; e < vectorLength; e++) {
+            dot += syn0[e] * syn1Neg[e];
+        }
+
+        if (dot > HS_MAX_EXP)
+            g = (code - 1) * alpha;
+        else if (dot < (T) - HS_MAX_EXP)
+            g = (code - 0) * alpha;
+        else {
+            int idx = (int) ((dot + (T) HS_MAX_EXP) * ((T) expLength / HS_MAX_EXP / 2.0));
+            if (idx >= expLength)
+                return;
+
+            if (idx < 0)
+                return;
+
+            g = ((T) code - expTable[idx]) * alpha;
+        }
+
+        // axpy1
+        for (int e = 0; e < vectorLength; e++) {
+            neu1e[e] = g * syn1Neg[e] + neu1e[e];
+        }
+
+        // axpy2
+        if (!isInference) {
+            for (int e = 0; e < vectorLength; e++) {
+                syn1Neg[e] = g * syn0[e] + syn1Neg[e];
+            }
+        }
+    }
+
+    template <typename T>
+    void nSampling_(void *vsyn0, void *vsyn1Neg, void *vexpTable, void *vneu1e, double alpha, int vectorLength, int code, int expLength, bool isInference, cudaStream_t* stream) {
+        nSamplingKernel<T><<<1,1,128, *stream>>>(vsyn0, vsyn1Neg, vexpTable, vneu1e, alpha, vectorLength, code, expLength, isInference);
+    }
+
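Both kernels above replace `sigmoid(dot)` with a table lookup: `idx = (dot + HS_MAX_EXP) * (expLength / HS_MAX_EXP / 2)`. A host-side sketch of why that index computation works. The table construction follows the classic word2vec convention (an assumption here: the diff shows only the lookup side, not how `expTable` is filled): entry k holds `sigmoid((k / expLength * 2 - 1) * MAX_EXP)`.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

constexpr float MAX_EXP = 6.0f; // matches HS_MAX_EXP in the diff

// Precompute sigmoid over the interval (-MAX_EXP, MAX_EXP).
std::vector<float> buildExpTable(int expLength) {
    std::vector<float> t(expLength);
    for (int k = 0; k < expLength; ++k) {
        const float x = (static_cast<float>(k) / expLength * 2.0f - 1.0f) * MAX_EXP;
        t[k] = 1.0f / (1.0f + std::exp(-x));
    }
    return t;
}

// Same index computation as hSoftmaxKernel/nSamplingKernel: maps dot in
// (-MAX_EXP, MAX_EXP) onto [0, expLength). Substituting the table formula
// shows the lookup returns (approximately) sigmoid(dot).
float sigmoidLookup(const std::vector<float>& table, float dot) {
    const int expLength = static_cast<int>(table.size());
    const int idx = static_cast<int>((dot + MAX_EXP) * (expLength / MAX_EXP / 2.0f));
    return table[idx];
}
```

The bounds checks in the kernels (`dot < -HS_MAX_EXP`, `idx >= expLength`) are the saturation cases: outside the tabulated interval the sigmoid is effectively 0 or 1 and the gradient contribution is dropped (hierarchical softmax) or clamped (negative sampling).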
int binarySearch(const int *haystack, const int needle, const int totalElements) {
|
int binarySearch(const int *haystack, const int needle, const int totalElements) {
|
||||||
return 0;
|
return 0;
|
||||||
@ -34,11 +131,392 @@ namespace nd4j {
|
|||||||
auto xType = syn0.dataType();
|
auto xType = syn0.dataType();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void checkContextKernel(int* context, T* syn0, T* neu1, int contextWidth, int vectorLength, int vocabSize) {
|
||||||
|
__shared__ bool hasError;
|
||||||
|
if (0 == threadIdx.x) {
|
||||||
|
hasError = false;
|
||||||
|
}
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int c = start; c < contextWidth; c += step) {
|
||||||
|
if (context[c] >= vocabSize)
|
||||||
|
hasError = true; //throw std::runtime_error("Bad context 4");
|
||||||
|
if (!hasError) {
|
||||||
|
T *syn0word = syn0 + (context[c] * vectorLength);
|
||||||
|
|
||||||
|
for (int i = 0; i < vectorLength; i++) {
|
||||||
|
neu1[i] += syn0word[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
if (hasError)
|
||||||
|
neu1[0] = DataTypeUtils::infOrMax<T>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void addInfVectorKernel(T* neu1, T* infVector, int vectorLength) {
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (auto i = start; i < vectorLength; i += step) {
|
||||||
|
neu1[i] += infVector[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void shiftKernel(T* neu1, T* infVector, int contextWidth, int vectorLength) {
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int i = start; i < vectorLength; i += step) {
|
||||||
|
neu1[i] /= contextWidth + int(infVector != nullptr); // ? 1 : 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void fillUpSynonymsKernel(int starter, int contextWidth, int vectorLength, int* lockedWords, int* context, T* neu1e, T* syn0) {
|
||||||
|
auto start = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
for (int c = starter + start; c < contextWidth; c += step) {
|
||||||
|
if (lockedWords[c] == 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
T *syn0word = syn0 + (context[c] * vectorLength);
|
||||||
|
|
||||||
|
for (int i = 0; i < vectorLength; i++) {
|
||||||
|
syn0word[i] += neu1e[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
void cbow_(LaunchContext* lc, void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords) {
    auto syn0 = reinterpret_cast<T *>(vsyn0);
    auto syn1 = reinterpret_cast<T *>(vsyn1);
    auto syn1Neg = reinterpret_cast<T *>(vsyn1Neg);
    auto expTable = reinterpret_cast<T *>(vexpTable);
    auto negTable = reinterpret_cast<T *>(vnegTable);
    auto infVector = reinterpret_cast<T *>(vinfVector);
    auto stream = lc->getCudaStream();

    T* neu1;    // = new T[vectorLength];
    T* neu1e;   // = new T[vectorLength];
    size_t buffSize = sizeof(T) * vectorLength;
    auto err = cudaMalloc(&neu1, buffSize);
    if (err != cudaSuccess)
        throw cuda_exception::build("cbow_: cannot allocate neu1", err);
    err = cudaMalloc(&neu1e, buffSize);
    if (err != cudaSuccess)
        throw cuda_exception::build("cbow_: cannot allocate neu1e", err);
    err = cudaMemset(neu1, 0, buffSize);
    err = cudaMemset(neu1e, 0, buffSize);

    // building neu1 for the current window
    checkContextKernel<T><<<1, 1, 128, *stream>>>(context, syn0, neu1, contextWidth, vectorLength, vocabSize);

    T checkVal;
    err = cudaMemcpy(&checkVal, neu1, sizeof(T), cudaMemcpyDeviceToHost);
    if (DataTypeUtils::infOrMax<T>() == checkVal)
        throw std::runtime_error("Bad context 4");

    // for inference we add an additional inference vector
    if (infVector != nullptr) {
        addInfVectorKernel<T><<<128, 256, 128, *stream>>>(neu1, infVector, vectorLength);
    }

    // average neu1
    if (contextWidth > 0) {
        shiftKernel<T><<<128, 256, 128, *stream>>>(neu1, infVector, contextWidth, vectorLength);
    }

    // softmax round
    if (hsRounds > 0) {
        for (int i = 0; i < hsRounds; i++) {
            if (indices[i] < 0 || indices[i] >= vocabSize)
                throw std::runtime_error("Bad context 5");

            T* syn1Shifted = syn1 + (indices[i] * vectorLength);
            hSoftmax_<T>(neu1, syn1Shifted, expTable, neu1e, alpha, vectorLength, codes[i], expLength, infVector != nullptr, stream);
        }
    }

    auto nsStarter = ngStarter;
    auto irow = nsStarter;
    if (nsRounds > 0) {
        for (int r = 0; r < nsRounds + 1; r++) {
            if (r == 0) {
                // target is known in advance
            } else {
                randomValue = randomValue * (unsigned long long) 25214903917 + 11;
                auto idx = nd4j::math::nd4j_abs<Nd4jLong>((randomValue >> 16) % negLength);
                irow = idx >= negLength ? -1 : static_cast<int>(negTable[idx]);

                if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1;
                if (irow == nsStarter)
                    continue;
            }

            nSampling_<T>(neu1, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream);
        }
    }

    // if we don't train words - we skip the start of idxSyn0
    int starter = trainWords == 1 ? 0 : contextWidth - numLabels;

    // propagate neu1e -> syn0
    if (infVector == nullptr) {
        fillUpSynonymsKernel<T><<<1, 1, 128, *stream>>>(starter, contextWidth, vectorLength, lockedWords, context, neu1e, syn0);
    } else {
        for (int i = 0; i < vectorLength; i++) {
            infVector[i] += neu1e[i];
        }
    }

    err = cudaFree(neu1);
    err = cudaFree(neu1e);
}
BUILD_SINGLE_TEMPLATE(template void cbow_, (LaunchContext* lc, void *syn0, void *syn1, void *syn1Neg, void *expTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *context, int *lockedWords, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int contextWidth, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const int numLabels, const bool trainWords), FLOAT_TYPES);
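The negative-sampling loop above advances the same linear congruential generator that word2vec uses on CPU (`x = x * 25214903917 + 11`, top bits indexing the unigram table). A host-side sketch of that row selection, with `negTable`, `negLength` and `vocabSize` as illustrative stand-ins for the real buffers:

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the negative-sample row selection used in the loop above.
// The table contents and sizes here are hypothetical; only the LCG
// constants and the fallback rule mirror the kernel code.
static int nextNegativeRow(unsigned long long &randomValue,
                           const int *negTable, int negLength,
                           int vocabSize) {
    // same LCG step as the device code
    randomValue = randomValue * 25214903917ULL + 11;
    long long idx = static_cast<long long>((randomValue >> 16) % negLength);
    int irow = idx >= negLength ? -1 : negTable[idx];
    // out-of-range rows fall back to a pseudo-random non-zero word id
    if (irow < 0 || irow >= vocabSize)
        irow = static_cast<int>(randomValue % (vocabSize - 1)) + 1;
    return irow;
}
```

The same state update is repeated per round, so a fixed seed reproduces the same sample sequence on CPU and GPU.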

template <typename T>
static __global__ void buildCurrentWindowKernel(int vocabSize, int contextWidth, int vectorLength, int* bContext, T* syn0, T* neu1, int* actualContext, int e) {
    // building neu1 for the current window
    auto start = blockIdx.x * blockDim.x + threadIdx.x;
    auto step = blockDim.x * gridDim.x;

    for (int c = start; c < contextWidth; c += step) {
        // getting the next context word
        auto cContext = bContext[c + (e * contextWidth)];

        // skipping padded values
        if (cContext < 0)
            continue;

        // if (cContext >= vocabSize)
        //     throw std::runtime_error("ContextID can't be >= vocab size");

        T *syn0word = syn0 + (cContext * vectorLength);

        for (int i = 0; i < vectorLength; i++)
            neu1[i] += syn0word[i];

        atomicAdd(actualContext, 1);
    }
}

template <typename T>
__global__ void arrangeNeuKernel(int vectorLength, T* neu1, T* infVector, int* actualContext) {
    auto start = blockIdx.x * blockDim.x + threadIdx.x;
    auto step = blockDim.x * gridDim.x;

    for (int i = start; i < vectorLength && *actualContext > 0; i += step)
        neu1[i] /= (*actualContext + int(infVector != nullptr));
}
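Taken together, `buildCurrentWindowKernel` sums the `syn0` rows of every non-padded context word into `neu1` while counting them, and `arrangeNeuKernel` divides by that count. A sequential host-side sketch of the combined effect (names and the row-major layout are illustrative, not the library API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side sketch of buildCurrentWindowKernel + arrangeNeuKernel:
// average the syn0 rows of all non-padded (>= 0) context words.
// `syn0` is assumed row-major [vocabSize x vectorLength].
static std::vector<float> averageContext(const std::vector<float> &syn0,
                                         const std::vector<int> &context,
                                         int vectorLength) {
    std::vector<float> neu1(vectorLength, 0.0f);
    int actualContext = 0;                      // the device-side atomic counter
    for (int c : context) {
        if (c < 0)                              // skip padded values, as the kernel does
            continue;
        for (int i = 0; i < vectorLength; i++)
            neu1[i] += syn0[static_cast<size_t>(c) * vectorLength + i];
        actualContext++;
    }
    if (actualContext > 0)
        for (auto &v : neu1)
            v /= actualContext;                 // arrangeNeuKernel's division
    return neu1;
}
```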

template <typename T>
__global__ void applyShiftKernel(int* bContext, int* bLocker, T* syn0, T* neu1e, int contextWidth, int vectorLength, int e, int starter) {
    auto step = blockDim.x * gridDim.x;
    auto start = blockDim.x * blockIdx.x + threadIdx.x;

    for (int c = starter + start; c < contextWidth; c += step) {
        // getting context
        auto cContext = bContext[c + (e * contextWidth)];
        auto cLock = bLocker[c + (e * contextWidth)];

        // skipping padded and locked values
        if (cContext < 0 || cLock == 1)
            continue;

        // if (cContext >= vocabSize)
        //     throw std::runtime_error("ContextID can't be > vocab size");

        // one word from context
        T *syn0word = syn0 + (cContext * vectorLength);

        for (int i = 0; i < vectorLength; i++)
            syn0word[i] += neu1e[i];
    }
}

template <typename T>
void cbowBatchExec_(LaunchContext* lc, NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads) {
    const auto syn0 = reinterpret_cast<T*>(s0.specialBuffer());
    const auto syn1 = reinterpret_cast<T*>(s1.specialBuffer());
    const auto syn1Neg = reinterpret_cast<T*>(s1n.specialBuffer());

    const auto expTable = reinterpret_cast<T*>(vexpTable);
    const auto negTable = reinterpret_cast<T*>(vnegTable);
    const auto infVector = reinterpret_cast<T*>(vinfVector);

    auto stream = lc->getCudaStream();

    indices.syncToHost();
    codes.syncToHost();
    negStarters.syncToHost();
    context.syncToHost();

    //const auto numThreads = omp_get_max_threads();
    const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1);
    const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1);
    const auto numTargets = context.sizeAt(0);
    const int contextWidth = context.sizeAt(1);
    const auto bContext = reinterpret_cast<int*>(context.buffer());
    const auto dContext = reinterpret_cast<int*>(context.specialBuffer());
    const auto bLocker = reinterpret_cast<int*>(lockedWords.buffer());
    const auto dLocker = reinterpret_cast<int*>(lockedWords.specialBuffer());
    const auto bIndices = reinterpret_cast<int*>(indices.buffer());
    const auto bCodes = reinterpret_cast<int8_t*>(codes.buffer());
    const auto bStarters = reinterpret_cast<int*>(negStarters.buffer());
    const auto numIndices = indices.isEmpty() ? 0 : indices.sizeAt(1);
    lr.syncToHost();
    nLabels.syncToHost();

    T* neu1;
    T* neu1e;
    auto cerr = cudaMalloc(&neu1, sizeof(T) * vectorLength);
    if (cerr) {
        throw cuda_exception::build("Cannot allocate temp vector buffer", cerr);
    }
    cerr = cudaMalloc(&neu1e, sizeof(T) * vectorLength);
    if (cerr) {
        throw cuda_exception::build("Cannot allocate temp vector buffer", cerr);
    }
    int* actualContext;
    cerr = cudaMalloc(&actualContext, sizeof(int));
    if (cerr) {
        throw cuda_exception::build("Cannot allocate counter buffer", cerr);
    }

    for (int e = 0; e < numTargets; e++) {

        // nullify temp arrays and the counter before each round, otherwise
        // results accumulate across targets
        cerr = cudaMemset(neu1, 0, sizeof(T) * vectorLength);
        cerr = cudaMemset(neu1e, 0, sizeof(T) * vectorLength);
        cerr = cudaMemset(actualContext, 0, sizeof(int));
        if (cerr) {
            throw cuda_exception::build("Cannot zero temp buffers", cerr);
        }

        auto alpha = lr.e<double>(e);
        auto numLabels = nLabels.isEmpty() ? 0 : nLabels.e<int>(e);

        buildCurrentWindowKernel<T><<<1, 1, 128, *stream>>>(vocabSize, contextWidth, vectorLength, dContext, syn0, neu1, actualContext, e);
        arrangeNeuKernel<T><<<1, 1, 128, *stream>>>(vectorLength, neu1, infVector, actualContext);

        // hierarchic softmax step
        if (!indices.isEmpty()) {
            for (int i = 0; i < numIndices; i++) {
                const int cIndex = bIndices[(e * numIndices) + i];
                const int cCode = bCodes[(e * numIndices) + i];

                // we're skipping padded values
                if (cIndex < 0)
                    continue;

                if (cIndex >= vocabSize)
                    throw std::runtime_error("Index can't be > vocab size");

                hSoftmax_<T>(neu1, syn1 + (cIndex * vectorLength), expTable, neu1e, alpha, vectorLength, cCode, expLength, false, stream);
            }
        }

        // negative sampling step
        if (!negStarters.isEmpty() && nsRounds > 0) {
            int irow = bStarters[e];
            const int nsStarter = irow;
            unsigned long long randomValue = nextRandom.e<Nd4jLong>(e);

            for (int r = 0; r < nsRounds + 1; r++) {
                // we're skipping rng on step 0: the target row is known in advance
                if (r != 0) {
                    randomValue = randomValue * (unsigned long long) 25214903917 + 11;
                    auto idx = nd4j::math::nd4j_abs<Nd4jLong>((randomValue >> 16) % negLength);
                    irow = idx >= negLength ? -1 : static_cast<int>(negTable[idx]);

                    if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1;
                    if (irow == nsStarter)
                        continue;
                }

                nSampling_<T>(neu1, s1n.bufferWithOffset(irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream);
            }
        }

        // if we're skipping labels
        int starter = trainWords == 1 ? 0 : contextWidth - numLabels;

        // applying previously averaged results
        applyShiftKernel<T><<<1, 1, 128, *stream>>>(dContext, dLocker, syn0, neu1e, contextWidth, vectorLength, e, starter);
    }

    cerr = cudaFree(neu1);
    if (cerr) {
        throw cuda_exception::build("Cannot deallocate temp buffer1", cerr);
    }
    cerr = cudaFree(neu1e);
    if (cerr) {
        throw cuda_exception::build("Cannot deallocate temp buffer1 E", cerr);
    }
    cerr = cudaFree(actualContext);
    if (cerr) {
        throw cuda_exception::build("Cannot deallocate temp buffer1", cerr);
    }
}
BUILD_SINGLE_TEMPLATE(template void cbowBatchExec_, (LaunchContext* lc, NDArray &s0, NDArray &s1, NDArray &s1n, void *vexpTable, void *vnegTable, void *vinfVector, NDArray &context, NDArray &lockedWords, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, NDArray &nLabels, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength, const bool trainWords, const int numThreads), FLOAT_TYPES);

void cbow(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &context, NDArray &lockedWords, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &numLabels, NDArray &inferenceVector, const bool trainWords, int numWorkers) {
    auto xType = syn0.dataType();
    auto lc = context.getContext();
    indices.syncToHost();
    NDArray::prepareSpecialUse({&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, &numLabels, &inferenceVector});
    //auto stream = lc->getCudaStream();
    if ((context.rankOf() == 0 || context.rankOf() == 1) && (indices.rankOf() == 1 || indices.rankOf() == 0)) {
        // single round case
        /*nd4j_printf("Row exec; ContextWidth: %i; LockedWords: %i; numLabels: %i; Train words: %i\n", (int) context.lengthOf(), (int) lockedWords.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e<int>(0), (int) trainWords);
        if (context.lengthOf() == 2) {
            context.printBuffer("context");
            lockedWords.printBuffer("locked");
            codes.printBuffer("codes");
            indices.printBuffer("indices");
        }*/

        auto hsRounds = codes.lengthOf();
        target.syncToHost();
        numLabels.syncToHost();
        alpha.syncToHost();
        codes.syncToHost();
        negTable.syncToHost();
        BUILD_SINGLE_SELECTOR(xType, cbow_, (lc, syn0.specialBuffer(), syn1.specialBuffer(), syn1Neg.specialBuffer(), expTable.specialBuffer(), negTable.buffer(), inferenceVector.specialBuffer(), target.isEmpty() ? -1 : target.e<int>(0), ngStarter.isEmpty() ? -1 : ngStarter.e<int>(0), reinterpret_cast<int *>(context.specialBuffer()), reinterpret_cast<int *>(lockedWords.specialBuffer()), reinterpret_cast<int *>(indices.buffer()), reinterpret_cast<int8_t *>(codes.buffer()), alpha.e<double>(0), randomValue.e<Nd4jLong>(0), (int) context.lengthOf(), hsRounds, nsRounds, (int) syn0.sizeAt(0), (int) syn0.sizeAt(1), (int) expTable.lengthOf(), (int) negTable.lengthOf(), numLabels.isEmpty() ? 0 : numLabels.e<int>(0), trainWords), FLOAT_TYPES);
    } else if (context.rankOf() == 2 && indices.rankOf() == 2) {
        // batch mode
        BUILD_SINGLE_SELECTOR(xType, cbowBatchExec_, (lc, syn0, syn1, syn1Neg, expTable.specialBuffer(), negTable.specialBuffer(), nullptr, context, lockedWords, target, ngStarter, indices, codes, alpha, randomValue, numLabels, nsRounds, syn0.sizeAt(0), syn0.sizeAt(1), expTable.lengthOf(), negTable.isEmpty() ? 0 : negTable.lengthOf(), trainWords, numWorkers), FLOAT_TYPES);
    } else
        throw std::runtime_error("CBOW: context must have rank 0/1 or 2");

    NDArray::registerSpecialUse({&syn0, &syn1, &syn1Neg, &expTable, &negTable, &target, &ngStarter}, {&context, &lockedWords, &indices, &codes, &alpha, &randomValue, &numLabels, &inferenceVector});
}
}
}
}

@@ -32,75 +32,71 @@ namespace helpers {

template <typename T>
static __global__ void stackKernel(void** inputList, void** inputShapeList, int inputListLength, Nd4jLong arrLen, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* tadShape, Nd4jLong *tadOffsets) {

    T* z = reinterpret_cast<T*>(vz);

    if(tadShape == nullptr) {   // scalar case

        for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < inputListLength; i += gridDim.x * blockDim.x)
            z[shape::getIndexOffset(i, zShapeInfo, inputListLength)] = reinterpret_cast<T*>(inputList[i])[0];
    }
    else {

        for (int t = blockIdx.x; t < inputListLength; t += gridDim.x) {

            auto tZ = z + tadOffsets[t];
            auto tX = reinterpret_cast<T*>(inputList[t]);
            auto xShapeInfo = reinterpret_cast<Nd4jLong*>(inputShapeList[t]);

            for (int e = threadIdx.x; e < arrLen; e += blockDim.x)
                tZ[shape::getIndexOffset(e, tadShape, arrLen)] = tX[shape::getIndexOffset(e, xShapeInfo, arrLen)];
        }
    }
}

///////////////////////////////////////////////////////////////////
template <typename T>
static void stack_(nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {

    const bool scalarCase = inArrs[0]->isScalar();

    const int threadsPerBlock = MAX_NUM_THREADS / 2;
    const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size();

    NDArray::prepareSpecialUse({outArr}, inArrs);

    std::vector<void const*> inputList(inArrs.size());
    std::vector<Nd4jLong const*> inputShapeList(inArrs.size());

    for (size_t i = 0; i < inputList.size(); ++i) {
        inputList[i]      = inArrs[i]->getSpecialBuffer();
        inputShapeList[i] = inArrs[i]->getSpecialShapeInfo();
    }

    PointersManager manager(context, "helpers::stack");
    auto dInBuffers   = (void **) manager.replicatePointer(inputList.data(), inputList.size() * sizeof(Nd4jLong*));
    auto dInShapeInfo = (void **) manager.replicatePointer(inputShapeList.data(), inputShapeList.size() * sizeof(Nd4jLong*));

    if(scalarCase) {
        stackKernel<T><<<blocksPerGrid, threadsPerBlock, 1024, *context->getCudaStream()>>>((void**)dInBuffers, (void**)dInShapeInfo, inputList.size(), inArrs[0]->lengthOf(), outArr->specialBuffer(), outArr->getSpecialShapeInfo(), nullptr, nullptr);
    }
    else {
        std::vector<int> axis = ShapeUtils::evalDimsToExclude(outArr->rankOf(), {dim});
        auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(outArr->getShapeInfo(), axis);
        stackKernel<T><<<blocksPerGrid, threadsPerBlock, 1024, *context->getCudaStream()>>>((void**)dInBuffers, (void**)dInShapeInfo, inputList.size(), inArrs[0]->lengthOf(), outArr->specialBuffer(), nullptr, packZ.specialShapeInfo(), packZ.specialOffsets());
    }
    manager.synchronize();

    NDArray::registerSpecialUse({outArr}, inArrs);
}

void stack(nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
    BUILD_SINGLE_SELECTOR(outArr->dataType(), stack_, (context, inArrs, outArr, dim), LIBND4J_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void stack_ , (nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim), LIBND4J_TYPES);

}
}

@@ -70,6 +70,7 @@ __host__ static void concatCudaLauncher(const int numOfArrs, const cudaStream_t

    concatCuda<T><<<512, 256, 1024, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo);
}
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES);

///////////////////////////////////////////////////////////////////
// x - input, y - paddings, z - output
@@ -167,6 +168,7 @@ static void padCudaLauncher(const int blocksPerGrid, const int threadsPerBlock,

    padCuda<X,Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(mode, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, padVal);
}
BUILD_DOUBLE_TEMPLATE(template void padCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int mode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const void* vPadVal), LIBND4J_TYPES, INTEGER_TYPES);

///////////////////////////////////////////////////////////////////
void pad(nd4j::LaunchContext * context, const int mode, const NDArray& input, const NDArray& paddings, NDArray& output, const NDArray& padValue) {
@@ -553,6 +555,396 @@ void scatterUpdate(nd4j::LaunchContext* context, NDArray& input, NDArray& update
    manager.synchronize();
}

///////////////////////////////////////////////////////////////////
// x - input, y - indices, z - output
template<typename X, typename Y>
__global__ static void gatherNDCuda(const void *vx, const Nd4jLong *xShapeInfo,
                                    const void *vy, const Nd4jLong *yShapeInfo,
                                          void *vz, const Nd4jLong *zShapeInfo) {

    const auto x = reinterpret_cast<const X*>(vx);
    const auto y = reinterpret_cast<const Y*>(vy);
          auto z = reinterpret_cast<X*>(vz);

    __shared__ int xRank, yRank, zRank, maxRank, yLastDim;
    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;

    if (threadIdx.x == 0) {

        extern __shared__ unsigned char shmem[];
        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);

        xRank = shape::rank(xShapeInfo);
        yRank = shape::rank(yShapeInfo);
        zRank = shape::rank(zShapeInfo);
        maxRank = nd4j::math::nd4j_max<int>(yRank, nd4j::math::nd4j_max<int>(xRank, zRank));

        zLen = shape::length(zShapeInfo);
        yLastDim = yShapeInfo[yRank];

        totalThreads = gridDim.x * blockDim.x;
    }
    __syncthreads();

    auto coord = sharedMem + threadIdx.x * maxRank;

    Nd4jLong *zCoordStart, *xCoordStart;

    if(yLastDim == xRank) {
        zCoordStart = coord;
        xCoordStart = coord;
    }
    else if(zRank >= xRank) {
        zCoordStart = coord;
        xCoordStart = coord + zRank - xRank;
    }
    else {
        zCoordStart = coord + xRank - zRank;
        xCoordStart = coord;
    }

    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

    for (Nd4jLong i = tid; i < zLen; i += totalThreads) {

        shape::index2coords(zRank, zShapeInfo + 1, i, zLen, zCoordStart);

        const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + zRank + 1, zCoordStart, zRank);

        // last y coordinate
        int coordToRestore;
        if(yLastDim != xRank)
            coordToRestore = static_cast<int>(zCoordStart[yRank - 1]);

        zCoordStart[yRank - 1] = 0;     // last y coordinate
        const auto yOffset = shape::getOffset(0, yShapeInfo + 1, yShapeInfo + yRank + 1, zCoordStart, yRank);

        // restore z coordinate
        if(yLastDim != xRank)
            zCoordStart[yRank - 1] = coordToRestore;

        // construct coordinates for x
        for(uint j = 0; j < yLastDim; ++j)
            xCoordStart[j] = y[yOffset + j * yShapeInfo[2 * yRank]];    // last stride

        const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, xCoordStart, xRank);

        z[zOffset] = x[xOffset];
    }
}

///////////////////////////////////////////////////////////////////
template<typename X, typename Y>
static void gatherNDCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
                                 const void *vx, const Nd4jLong *xShapeInfo,
                                 const void *vy, const Nd4jLong *yShapeInfo,
                                       void *vz, const Nd4jLong *zShapeInfo) {

    gatherNDCuda<X,Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
}
BUILD_DOUBLE_TEMPLATE(template void gatherNDCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo), LIBND4J_TYPES, INTEGER_TYPES);

///////////////////////////////////////////////////////////////////
void gatherND(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) {

    const int maxRank = nd4j::math::nd4j_max<int>(indices.rankOf(), nd4j::math::nd4j_max<int>(input.rankOf(), output.rankOf()));

    const int threadsPerBlock = MAX_NUM_THREADS;
    const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
    const int sharedMem = 8 * threadsPerBlock * maxRank + 128;

    const auto xType = input.dataType();
    const auto yType = indices.dataType();

    PointersManager manager(context, "gatherND");

    NDArray::prepareSpecialUse({&output}, {&input, &indices});
    BUILD_DOUBLE_SELECTOR(xType, yType, gatherNDCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), indices.getSpecialBuffer(), indices.getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo()), LIBND4J_TYPES, INTEGER_TYPES);
    NDArray::registerSpecialUse({&output}, {&input, &indices});

    manager.synchronize();
}
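The kernel's `yLastDim == xRank` branch covers the "full index" case of gather_nd: each row of the last dimension of `indices` is a complete coordinate into the input, so the output holds one input element per index row. A minimal host-side sketch for a 2-D input (shapes and names are illustrative, not the library API):

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Minimal host-side sketch of gather_nd semantics for the full-index case
// (last dim of indices == rank of x): each (i, j) pair selects one element
// of a row-major [rows x cols] input.
static std::vector<float> gatherNdFull(const std::vector<float> &x, int cols,
                                       const std::vector<std::pair<int, int>> &indices) {
    std::vector<float> out;
    out.reserve(indices.size());
    for (const auto &ij : indices)
        out.push_back(x[static_cast<size_t>(ij.first) * cols + ij.second]);
    return out;
}
```

When the index rows are shorter than the input rank, gather_nd instead selects whole sub-tensors, which is why the kernel has to juggle separate `xCoordStart`/`zCoordStart` windows into the shared coordinate buffer.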

//////////////////////////////////////////////////////////////////////////
// x - input, y - gradO, z - gradI
template<typename X, typename Z>
__global__ static void clipByNormBPWholeArrCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) {

    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

    if(tid >= shape::length(zShapeInfo))
        return;

    const auto x = reinterpret_cast<const X*>(vx);
    const auto y = reinterpret_cast<const Z*>(vy);
          auto z = reinterpret_cast<Z*>(vz);

    auto reducBuff = reinterpret_cast<Z*>(vreducBuff);
    uint* count = reinterpret_cast<uint*>(vreducBuff) + 16384;

    __shared__ Z* shMem;
    __shared__ Nd4jLong len;
    __shared__ bool amIinLastBlock;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char shmem[];
        shMem = reinterpret_cast<Z*>(shmem);

        len = shape::length(zShapeInfo);    // xLen = yLen = zLen
    }
    __syncthreads();

    // fill shared memory with array elements
    const auto xVal = x[shape::getIndexOffset(tid, xShapeInfo, len)];
    const auto yVal = y[shape::getIndexOffset(tid, yShapeInfo, len)];

    shMem[2*threadIdx.x]     = static_cast<Z>(xVal * xVal);     // for norm
    shMem[2*threadIdx.x + 1] = static_cast<Z>(xVal * yVal);     // for input * gradO

    __syncthreads();

    // accumulate sum per block
    for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {

        if (threadIdx.x < activeThreads && tid + activeThreads < len) {
            shMem[2*threadIdx.x]     += shMem[2*(threadIdx.x + activeThreads)];
            shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
        }
        __syncthreads();
    }

    // store accumulated sums in the reduction buffer (reducBuff)
    if (threadIdx.x == 0) {

        reducBuff[2*blockIdx.x]     = shMem[0];
        reducBuff[2*blockIdx.x + 1] = shMem[1];

        __threadfence();

        amIinLastBlock = gridDim.x == 1 || (atomicInc(count, gridDim.x) == gridDim.x - 1);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// shared memory of last block is used for final summation of values stored in reduction buffer
|
||||||
|
if (amIinLastBlock) {
|
||||||
|
|
||||||
|
for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) {
|
||||||
|
|
||||||
|
shMem[2*threadIdx.x] = (i == threadIdx.x ) ? reducBuff[2*i] : reducBuff[2*i] + shMem[2*threadIdx.x];
|
||||||
|
shMem[2*threadIdx.x + 1] = (i == threadIdx.x ) ? reducBuff[2*i + 1] : reducBuff[2*i + 1] + shMem[2*threadIdx.x + 1];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// accumulate sum
|
||||||
|
for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
|
||||||
|
|
||||||
|
if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < gridDim.x) {
|
||||||
|
shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
|
||||||
|
shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
|
reducBuff[0] = math::nd4j_sqrt<Z,Z>(shMem[0]);
|
||||||
|
reducBuff[1] = shMem[1];
|
||||||
|
count = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}

//////////////////////////////////////////////////////////////////////////
// x - input, y - gradO, z - gradI
template<typename X, typename Z>
__global__ static void clipByNormBPCalcGradCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) {

    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

    const Nd4jLong len = shape::length(zShapeInfo);   // xLen = yLen = zLen

    if(tid >= len)
        return;

    const auto x = reinterpret_cast<const X*>(vx);
    const auto y = reinterpret_cast<const Z*>(vy);
    auto z = reinterpret_cast<Z*>(vz);

    __shared__ Z norm, sumOfProd;

    if (threadIdx.x == 0) {

        norm      = reinterpret_cast<Z*>(vreducBuff)[0];
        sumOfProd = reinterpret_cast<Z*>(vreducBuff)[1];
    }
    __syncthreads();

    const auto yOffset = shape::getIndexOffset(tid, yShapeInfo, len);
    const auto zOffset = shape::getIndexOffset(tid, zShapeInfo, len);

    if(norm > clipNormVal) {

        const auto xOffset = shape::getIndexOffset(tid, xShapeInfo, len);

        const Z factor1 = static_cast<Z>(1) / norm;   // 1 / norm
        const Z factor2 = factor1 / (norm * norm);    // 1 / (norm * norm * norm)

        z[zOffset] = clipNormVal * (factor1 * y[yOffset] - factor2 * sumOfProd * x[xOffset]);
    }
    else {
        z[zOffset] = y[yOffset];
    }
}

//////////////////////////////////////////////////////////////////////////
// x - input, y - gradO, z - gradI
template<typename X, typename Z>
__global__ static void clipByNormBPTadsCuda(const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const void* vy, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const Z clipNormVal) {

    const auto x = reinterpret_cast<const X*>(vx);
    const auto y = reinterpret_cast<const Z*>(vy);
    auto z = reinterpret_cast<Z*>(vz);

    __shared__ Z* shMem;
    __shared__ Nd4jLong tadLen;

    if (threadIdx.x == 0) {

        extern __shared__ unsigned char shmem[];
        shMem = reinterpret_cast<Z*>(shmem);
        tadLen = shape::length(zTadShapeInfo);   // xTadLen = yTadLen = zTadLen
    }
    __syncthreads();

    const auto* xTad = x + xTadOffsets[blockIdx.x];
    const auto* yTad = y + yTadOffsets[blockIdx.x];
          auto* zTad = z + zTadOffsets[blockIdx.x];

    // *** FIRST STAGE - ACCUMULATE REQUIRED SUMS *** //

    Z norm = 0;
    Z sumOfProd = 0;

    for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) {

        const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo, tadLen);
        const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo, tadLen);

        shMem[2*threadIdx.x]     = static_cast<Z>(xTad[xOffset] * xTad[xOffset]);   // for norm
        shMem[2*threadIdx.x + 1] = static_cast<Z>(xTad[xOffset] * yTad[yOffset]);   // for input * gradO

        __syncthreads();

        // accumulate sum per block
        for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {

            if (threadIdx.x < activeThreads && i + activeThreads < tadLen) {
                shMem[2*threadIdx.x]     += shMem[2*(threadIdx.x + activeThreads)];
                shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
            }
            __syncthreads();
        }

        norm += shMem[0];
        sumOfProd += shMem[1];
    }

    // *** SECOND STAGE - GRADIENT CALCULATION *** //

    norm = math::nd4j_sqrt<Z,Z>(norm);

    for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) {

        const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo, tadLen);
        const auto zOffset = shape::getIndexOffset(i, zTadShapeInfo, tadLen);

        if(norm > clipNormVal) {

            const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo, tadLen);

            const Z factor1 = static_cast<Z>(1) / norm;   // 1 / norm
            const Z factor2 = factor1 / (norm * norm);    // 1 / (norm * norm * norm)

            zTad[zOffset] = clipNormVal * (factor1 * yTad[yOffset] - factor2 * sumOfProd * xTad[xOffset]);
        }
        else {
            zTad[zOffset] = yTad[yOffset];
        }
    }
}

//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
static void clipByNormBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
                                     const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
                                     const void* vy, const Nd4jLong* yShapeInfo, const Nd4jLong* yTadOffsets,
                                     void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
                                     void* vreducBuff, const double clipNormVal) {

    if(xTadOffsets == nullptr) {   // means the whole array
        clipByNormBPWholeArrCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast<Z>(clipNormVal));
        clipByNormBPCalcGradCuda<X,Z><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast<Z>(clipNormVal));
    }
    else   // means TADs are used
        clipByNormBPTadsCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, xTadOffsets, vy, yShapeInfo, yTadOffsets, vz, zShapeInfo, zTadOffsets, static_cast<Z>(clipNormVal));
}
BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), LIBND4J_TYPES, FLOAT_TYPES);

//////////////////////////////////////////////////////////////////////////
void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {

    PointersManager manager(context, "clipByNormBP");

    const double clipNormVal = clipNorm.e<double>(0);

    const auto xType = input.dataType();
    const auto zType = gradI.dataType();

    const int threadsPerBlock = MAX_NUM_THREADS / 2;
    const int sharedMem = threadsPerBlock * 2 * input.sizeOfT() + 128;

    NDArray::prepareSpecialUse({&gradI}, {&input, &gradO});

    if(dimensions.empty() || dimensions.size() == input.rankOf()) {   // means the whole array

        const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
        BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), nullptr, gradO.getSpecialBuffer(), gradO.getSpecialShapeInfo(), nullptr, gradI.getSpecialBuffer(), gradI.getSpecialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), LIBND4J_TYPES, FLOAT_TYPES);
    }
    else {   // means TADs are used

        auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimensions);
        auto packY = ConstantTadHelper::getInstance()->tadForDimensions(gradO.getShapeInfo(), dimensions);
        auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.getShapeInfo(), dimensions);

        const int blocksPerGrid = packX.numberOfTads();
        BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.getSpecialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.getSpecialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), LIBND4J_TYPES, FLOAT_TYPES);
    }

    NDArray::registerSpecialUse({&gradI}, {&input, &gradO});

    manager.synchronize();
}
@@ -697,18 +1089,6 @@ void scatterUpdate(nd4j::LaunchContext* context, NDArray& input, NDArray& update

 BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);

-////////////////////////////////////////////////////////////////////////
-template<typename T>
-static void gatherND_(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) {
-
-}
-
-void gatherND(nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) {
-    BUILD_SINGLE_SELECTOR(input.dataType(), gatherND_, (context, input, indices, output), LIBND4J_TYPES);
-}
-
-BUILD_SINGLE_TEMPLATE(template void gatherND_, (nd4j::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output), LIBND4J_TYPES);
-
 //////////////////////////////////////////////////////////////////////////
@@ -1037,18 +1417,6 @@ void eye(nd4j::LaunchContext * context, NDArray& output) {

 BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (nd4j::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, nd4j::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);

-//////////////////////////////////////////////////////////////////////////
-template<typename T>
-static void clipByNormBP_(nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
-
-}
-
-void clipByNormBP(nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
-    BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBP_, (context, input, gradO, gradI, dimensions, clipNorm), FLOAT_TYPES);
-}
-
-BUILD_SINGLE_TEMPLATE(template void clipByNormBP_, (nd4j::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm), FLOAT_TYPES);
-
 //////////////////////////////////////////////////////////////////////////
 template<typename T>
@@ -1374,8 +1742,7 @@ void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs,
 }

-BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES);
 BUILD_DOUBLE_TEMPLATE(template void padCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const int mode, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const void* vPadVal), LIBND4J_TYPES, INTEGER_TYPES);

 }
 }

libnd4j/include/ops/declarable/helpers/gradient.h (new file, 37 lines)
@@ -0,0 +1,37 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author sgazeos@gmail.com
//
#ifndef __GRADIENT_H_HELPERS__
#define __GRADIENT_H_HELPERS__
#include <op_boilerplate.h>
#include <NDArray.h>

namespace nd4j {
namespace ops {
namespace helpers {

    /*
     * applyGradientDescent: calculate z = x - y * w.
     * */
    void applyGradientDescent(nd4j::LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output);

}
}
}
#endif
@@ -45,29 +45,29 @@ namespace nd4j {
                   const NDArray* iSeq, const NDArray* cSeq, const NDArray* fSeq, const NDArray* oSeq, const NDArray* zSeq,
                   const NDArray* hSeq, const NDArray* ySeq, const std::vector<double>& params, const int dataFormat){

-    int seqLen, mb, inSize, outSize;
+    int seqLen, bS, nIn, nOut;

     if(dataFormat == 0) {
         seqLen = xSeq->sizeAt(0);
-        mb = xSeq->sizeAt(1);
-        inSize = xSeq->sizeAt(2);
-        outSize = iSeq->sizeAt(2);
+        bS = xSeq->sizeAt(1);
+        nIn = xSeq->sizeAt(2);
+        nOut = iSeq->sizeAt(2);
     }
     else if(dataFormat == 1) {
         seqLen = xSeq->sizeAt(2);
-        mb = xSeq->sizeAt(0);
-        inSize = xSeq->sizeAt(1);
-        outSize = iSeq->sizeAt(1);
+        bS = xSeq->sizeAt(0);
+        nIn = xSeq->sizeAt(1);
+        nOut = iSeq->sizeAt(1);
     }
     else if(dataFormat == 2) {
         seqLen = xSeq->sizeAt(1);
-        mb = xSeq->sizeAt(0);
-        inSize = xSeq->sizeAt(2);
-        outSize = iSeq->sizeAt(2);
+        bS = xSeq->sizeAt(0);
+        nIn = xSeq->sizeAt(2);
+        nOut = iSeq->sizeAt(2);
     }

-    const std::vector<Nd4jLong> inSliceShape({mb,inSize});
-    const std::vector<Nd4jLong> outSliceShape({mb,outSize});
+    const std::vector<Nd4jLong> inSliceShape({bS,nIn});
+    const std::vector<Nd4jLong> outSliceShape({bS,nOut});

     auto c_t1 = const_cast<NDArray*>(c0);
     auto y_t1 = const_cast<NDArray*>(y0);
@@ -105,11 +105,11 @@ namespace nd4j {
 void lstmTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* c0, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b,
                   NDArray* h, NDArray* c, const std::vector<double>& params) {

-    // x   input [time x bS x inSize]
+    // x   input [time x bS x nIn]
     // h0  initial cell output (at time step = 0) [bS x numProj], in case of projection=false -> numProj == numUnits !!!
     // c0  initial cell state (at time step = 0) [bS x numUnits],

-    // Wx  input-to-hidden weights, [inSize x 4*numUnits]
+    // Wx  input-to-hidden weights, [nIn x 4*numUnits]
     // Wh  hidden-to-hidden weights, [numProj x 4*numUnits]
     // Wc  diagonal weights for peephole connections [3*numUnits]
     // Wp  projection weights [numUnits x numProj]
@@ -47,30 +47,17 @@ namespace helpers {

 //////////////////////////////////////////////////////////////////////////
 static NDArray timeSubset(const NDArray* arr, const int t, const int dataFormat){
-    if(dataFormat == 0){
-        //TNS: shape [timeLength, numExamples, inOutSize]
-        auto x = (*arr)({t,t+1, 0,0, 0,0});
-        const std::vector<Nd4jLong> newShape({arr->sizeAt(1),arr->sizeAt(2)});
-        return x.reshape(arr->ordering(), newShape);
-    } else if(dataFormat == 1){
-        //NST: shape [numExamples, inOutSize, timeLength]
-        auto x = (*arr)({0,0, 0,0, t,t+1});
-        const std::vector<Nd4jLong> newShape({arr->sizeAt(0),arr->sizeAt(1)});
-        return x.reshape(arr->ordering(), newShape);
-    } else {
-        //NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout
-        auto x = (*arr)({0,0, t,t+1, 0,0});
-        const std::vector<Nd4jLong> newShape({arr->sizeAt(0),arr->sizeAt(2)});
-        return x.reshape(arr->ordering(), newShape);
-    }
-}
-
-//////////////////////////////////////////////////////////////////////////
-template <typename T>
-static FORCEINLINE void clipping(NDArray* arr, T limit) {
-    arr->applyScalar(scalar::LstmClip, limit);
+    if(dataFormat == 0) {        // TNS: shape [timeLength, numExamples, inOutSize]
+        return (*arr)({t,t+1, 0,0, 0,0});
+    }
+    else if(dataFormat == 1) {   // NST: shape [numExamples, inOutSize, timeLength]
+        return (*arr)({0,0, 0,0, t,t+1});
+    }
+    else {                       // NTS: shape [numExamples, timeLength, inOutSize] - TF "time_major=false" layout
+        return (*arr)({0,0, t,t+1, 0,0});
+    }
 }


 void lstmCell(nd4j::LaunchContext * context, const NDArray* xt, const NDArray* ht_1, const NDArray* ct_1, const NDArray* Wx, const NDArray* Wh, const NDArray* Wc, const NDArray* Wp, const NDArray* b,
               NDArray* ht, NDArray* ct, const std::vector<double>& params);
@@ -26,6 +26,12 @@
 namespace nd4j {
 namespace ops {
 namespace helpers {

+    void batchToSpace(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const uint cropBottom, const uint cropTop, const uint cropLeft, const uint cropRight, const uint blockSize);
+
+    void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& output, const uint padBottom, const uint padTop, const uint padLeft, const uint padRight, const uint blockSize);
+
+    /*
     // this method MUST be platform-specific

     template <typename T, int NUM_BLOCK_DIMS, bool B2S>
@@ -76,6 +82,7 @@ namespace helpers {
     Nd4jStatus _spaceToBatch(nd4j::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector<Nd4jLong> &internal_input_shape, std::vector<Nd4jLong> &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *paddings);

     Nd4jStatus _batchToSpace(nd4j::LaunchContext * context, int internal_block_dims, NDArray *input, NDArray *output, std::vector<Nd4jLong> &internal_input_shape, std::vector<Nd4jLong> &internal_output_shape, Nd4jLong *block_shape, Nd4jLong *crops);
+    */
 }
 }
 }
@@ -28,7 +28,7 @@ namespace nd4j {
 namespace ops {
 namespace helpers {

-    void stack(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray* outArr, const int dim);
+    void stack(nd4j::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim);

 }
@@ -114,6 +114,9 @@ namespace nd4j {
     auto var = block->isFastPath() ? block->fastpath_out()[0] : block->variable(p)->getNDArray();
     var->p(Nd4jLong(0), status == ND4J_STATUS_TRUE ? 1.0f : 0.0f);

+    // for the CPU backend that's a nop, but for CUDA-like archs this will update the special buffer
+    var->syncToDevice();
+
     if (status == ND4J_STATUS_FALSE || status == ND4J_STATUS_TRUE)
         return ND4J_STATUS_OK;
@@ -777,7 +777,7 @@ namespace nd4j {
     block.setDataType(0, type);
     block.fillInputs(in);
     block.markInplace(isInplace);
-    block.setRNG(ProviderRNG::getInstance().getRNG());
+    // block.setRNG(ProviderRNG::getInstance().getRNG());

     for (int e = 0; e < tArgs.size(); e++)
         block.getTArguments()->emplace_back(tArgs.at(e));
@@ -33,7 +33,6 @@ public:
 TEST_F(ArrayOptionsTests, TestShape_Basic_0) {
     shape[5] = 1;

-
     ASSERT_TRUE(ArrayOptions::isNewFormat(shape));
     ASSERT_FALSE(ArrayOptions::isSparseArray(shape));
 }
@@ -289,8 +289,8 @@ TEST_F(ContextTests, test_short_context_1) {
     auto array1 = NDArrayFactory::create<float>('c', {3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f});
     Context ctx(1);

-    ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), nullptr, nullptr);
-    ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), nullptr, nullptr);
+    ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo());
+    ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo());

     ASSERT_EQ(2, ctx.width());
@@ -303,8 +303,14 @@ TEST_F(ContextTests, test_short_context_1) {
     ASSERT_TRUE(input0->buffer() == array0.buffer());
     ASSERT_TRUE(input0->shapeInfo() == array0.shapeInfo());

+    ASSERT_TRUE(input0->specialBuffer() == array0.specialBuffer());
+    ASSERT_TRUE(input0->specialShapeInfo() == array0.specialShapeInfo());
+
     ASSERT_TRUE(input1->buffer() == array1.buffer());
     ASSERT_TRUE(input1->shapeInfo() == array1.shapeInfo());
+
+    ASSERT_TRUE(input1->specialBuffer() == array1.specialBuffer());
+    ASSERT_TRUE(input1->specialShapeInfo() == array1.specialShapeInfo());
 }

 TEST_F(ContextTests, test_short_context_2) {
@@ -315,9 +321,9 @@ TEST_F(ContextTests, test_short_context_2) {
     auto exp = NDArrayFactory::create<float>('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f});
     Context ctx(1);

-    ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), nullptr, nullptr);
-    ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), nullptr, nullptr);
-    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), nullptr, nullptr);
+    ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo());
+    ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo());
+    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());

     ASSERT_EQ(2, ctx.width());
@@ -334,8 +340,8 @@ TEST_F(ContextTests, test_short_context_3) {
     auto exp = NDArrayFactory::create<float>('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f});
     Context ctx(1);

-    ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), nullptr, nullptr);
-    ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), nullptr, nullptr);
+    ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo());
+    ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo());

     ASSERT_EQ(2, ctx.width());
@@ -87,6 +87,9 @@ TEST_F(CudaBasicsTests1, TestPairwise_1) {
 cudaError_t dZ = cudaStreamCreate(reinterpret_cast<cudaStream_t *>(&nativeStream));
 auto stream = reinterpret_cast<cudaStream_t *>(&nativeStream);

+x.dataBuffer()->allocatePrimary();
+x.syncToHost();

 cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream);
 cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream);

@@ -95,6 +98,8 @@ TEST_F(CudaBasicsTests1, TestPairwise_1) {
 res = cudaStreamSynchronize(*stream);
 ASSERT_EQ(0, res);

+z.dataBuffer()->allocatePrimary();

 cudaMemcpyAsync(z.buffer(), devBufferPtrZ, z.lengthOf() * x.sizeOfT(), cudaMemcpyDeviceToHost, *stream);
 res = cudaStreamSynchronize(*stream);
 ASSERT_EQ(0, res);
@@ -103,6 +108,9 @@ TEST_F(CudaBasicsTests1, TestPairwise_1) {
 cudaFree(devBufferPtrZ);
 cudaFree(devShapePtrX);

+// needed due to memcpy
+z.tickWriteHost();

 for (int e = 0; e < z.lengthOf(); e++) {
 ASSERT_NEAR(exp.e<double>(e), z.e<double>(e), 1e-5);
 }
@@ -116,11 +124,11 @@ TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) {
 NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, nd4j::DataType::BFLOAT16);
 NDArray x3('c', {2,2}, {0, -1, 0, 1}, nd4j::DataType::BOOL);

-NDArray scalar('c', {0}, {0}, nd4j::DataType::INT64);
+NDArray scalar('c', {}, {0}, nd4j::DataType::INT64);

-NDArray exp1('c', {0}, {3}, nd4j::DataType::INT64);
+NDArray exp1('c', {}, {3}, nd4j::DataType::INT64);
-NDArray exp2('c', {0}, {2}, nd4j::DataType::INT64);
+NDArray exp2('c', {}, {2}, nd4j::DataType::INT64);
-NDArray exp3('c', {0}, {1}, nd4j::DataType::INT64);
+NDArray exp3('c', {}, {1}, nd4j::DataType::INT64);

 void *dX1, *dX2, *dX3, *dZ;
 Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo;
@@ -140,6 +148,11 @@ TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) {
 cudaResult = cudaStreamCreate(&stream);
 ASSERT_EQ(0, cudaResult);

+x1.syncToHost();
+x2.syncToHost();
+x3.syncToHost();
+scalar.syncToHost();

 cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), cudaMemcpyHostToDevice, stream);
 cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), cudaMemcpyHostToDevice, stream);
 cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), cudaMemcpyHostToDevice, stream);
@@ -152,7 +165,7 @@ TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) {
 cudaResult = cudaMalloc(reinterpret_cast<void **>(&reductionPointer), 1024*1024);
 ASSERT_EQ(0, cudaResult);

-LaunchContext lc(&stream, reductionPointer);
+LaunchContext lc(&stream, LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getScalarPointer(), LaunchContext::defaultContext()->getAllocationPointer());

 /***************************************/

@@ -172,6 +185,8 @@ TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) {
 cudaResult = cudaStreamSynchronize(stream);
 ASSERT_EQ(0, cudaResult);

+scalar.tickWriteHost();

 ASSERT_NEAR(exp1.e<float>(0), scalar.e<float>(0), 1e-5);

 /***************************************/
@@ -236,11 +251,12 @@ TEST_F(CudaBasicsTests1, execReduce3Scalar_1) {
 NDArray x2('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32);
 NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
 NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE);
-NDArray exp1('c', {0}, {-30}, nd4j::DataType::FLOAT32);
-NDArray exp2('c', {0}, {15}, nd4j::DataType::DOUBLE);
-NDArray scalar1('c', {0}, {100}, nd4j::DataType::FLOAT32);
-NDArray scalar2('c', {0}, {100}, nd4j::DataType::DOUBLE);
+
+NDArray exp1('c', {}, {-30.f}, nd4j::DataType::FLOAT32);
+NDArray exp2('c', {}, {15.}, nd4j::DataType::DOUBLE);
+
+NDArray scalar1('c', {}, {100.f}, nd4j::DataType::FLOAT32);
+NDArray scalar2('c', {}, {100.}, nd4j::DataType::DOUBLE);

 void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2;
 Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo;
@@ -262,6 +278,13 @@ TEST_F(CudaBasicsTests1, execReduce3Scalar_1) {
 cudaResult = cudaStreamCreate(&stream);
 ASSERT_EQ(0, cudaResult);

+x1.syncToHost();
+x2.syncToHost();
+x3.syncToHost();
+x4.syncToHost();
+scalar1.syncToHost();
+scalar2.syncToHost();

 cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), cudaMemcpyHostToDevice, stream);
 cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), cudaMemcpyHostToDevice, stream);
 cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), cudaMemcpyHostToDevice, stream);
@@ -288,6 +311,9 @@ TEST_F(CudaBasicsTests1, execReduce3Scalar_1) {
 cudaResult = cudaStreamSynchronize(stream);
 ASSERT_EQ(0, cudaResult);

+scalar1.tickWriteHost();
+scalar2.tickWriteHost();

 cudaMemcpyAsync(scalar1.buffer(), dZ1, scalar1.lengthOf() * scalar1.sizeOfT(), cudaMemcpyDeviceToHost, stream);

 cudaResult = cudaStreamSynchronize(stream);
@@ -327,11 +353,16 @@ TEST_F(CudaBasicsTests1, execReduce3_1) {
 NDArray x('c', {2,2}, {1,2,3,4}, nd4j::DataType::INT32);
 NDArray y('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32);

-NDArray exp('c', {0}, {-30}, nd4j::DataType::FLOAT32);
+NDArray exp('c', {}, {-30.f}, nd4j::DataType::FLOAT32);
-NDArray z('c', {0}, {100}, nd4j::DataType::FLOAT32);
+NDArray z('c', {}, {100.f}, nd4j::DataType::FLOAT32);

 std::vector<int> dimensions = {0, 1};

+x.syncToHost();
+y.syncToHost();
+z.syncToHost();

 std::vector<std::pair<void*,size_t>> hostData;
 hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions
 std::vector<void*> devicePtrs(hostData.size(), nullptr);
@@ -354,7 +385,7 @@ TEST_F(CudaBasicsTests1, execReduce3_1) {
 nullptr, nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -374,8 +405,8 @@ TEST_F(CudaBasicsTests1, execReduce3_2) {
 NDArray x('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
 NDArray y('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE);

-NDArray exp('c', {0}, {15}, nd4j::DataType::DOUBLE);
+NDArray exp('c', {}, {15.}, nd4j::DataType::DOUBLE);
-NDArray z('c', {0}, {100}, nd4j::DataType::DOUBLE);
+NDArray z('c', {}, {100.}, nd4j::DataType::DOUBLE);

 std::vector<int> dimensions = {0, 1};

@@ -404,7 +435,7 @@ TEST_F(CudaBasicsTests1, execReduce3_2) {


 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -470,7 +501,7 @@ TEST_F(CudaBasicsTests1, execReduce3_3) {
 (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -536,7 +567,7 @@ TEST_F(CudaBasicsTests1, execReduce3_4) {
 (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -602,7 +633,7 @@ TEST_F(CudaBasicsTests1, execReduce3_5) {
 (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -668,7 +699,7 @@ TEST_F(CudaBasicsTests1, execReduce3All_1) {
 (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -734,7 +765,7 @@ TEST_F(CudaBasicsTests1, execReduce3All_2) {
 (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -789,7 +820,7 @@ TEST_F(CudaBasicsTests1, execIndexReduce_1) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -851,7 +882,7 @@ TEST_F(CudaBasicsTests1, execIndexReduce_2) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -912,7 +943,7 @@ TEST_F(CudaBasicsTests1, execIndexReduce_3) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -934,7 +965,7 @@ TEST_F(CudaBasicsTests1, execScalar_1) {

 NDArray x('c', {2,3}, {0,1,2,3,4,5}, nd4j::DataType::INT64);
 NDArray exp('c',{2,3}, {0,0,1,1,2,2}, nd4j::DataType::INT64);
-NDArray scalar('c',{0}, {2}, nd4j::DataType::FLOAT32);
+NDArray scalar('c',{}, {2.f}, nd4j::DataType::FLOAT32);
 NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::INT64);

 // create cuda stream and LaunchContext
@@ -951,7 +982,7 @@ TEST_F(CudaBasicsTests1, execScalar_1) {
 nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -969,7 +1000,7 @@ TEST_F(CudaBasicsTests1, execScalar_2) {

 NDArray x('c', {2,3}, {-1,-2,-3,-4,-5,-6}, nd4j::DataType::INT64);
 NDArray exp('c',{2,3}, {10,10,10,10,10,10}, nd4j::DataType::FLOAT32);
-NDArray scalar('c',{0}, {10}, nd4j::DataType::FLOAT32);
+NDArray scalar('c',{}, {10.f}, nd4j::DataType::FLOAT32);
 NDArray z('c', {2,3}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);

 // create cuda stream and LaunchContext
@@ -986,7 +1017,7 @@ TEST_F(CudaBasicsTests1, execScalar_2) {
 nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1044,7 +1075,7 @@ TEST_F(CudaBasicsTests1, execScalar_3) {
 nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1062,7 +1093,7 @@ TEST_F(CudaBasicsTests1, execScalar_3) {
 TEST_F(CudaBasicsTests1, execScalarBool_1) {

 NDArray x('c', {2,3}, {-1,-2,0,1,2,3}, nd4j::DataType::BFLOAT16);
-NDArray scalar('c',{0}, {0}, nd4j::DataType::BFLOAT16);
+NDArray scalar('c',{}, {0}, nd4j::DataType::BFLOAT16);
 NDArray exp('c',{2,3}, {0,0,0,1,1,1}, nd4j::DataType::BOOL);
 NDArray z('c', {2,3}, {100,100,100,100,100,100,}, nd4j::DataType::BOOL);

@@ -1081,7 +1112,7 @@ TEST_F(CudaBasicsTests1, execScalarBool_1) {
 nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1134,7 +1165,7 @@ TEST_F(CudaBasicsTests1, execScalarBool_2) {
 nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1194,7 +1225,7 @@ TEST_F(CudaBasicsTests1, execBroadcast_1) {
 nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1254,7 +1285,7 @@ TEST_F(CudaBasicsTests1, execBroadcast_2) {
 nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1311,7 +1342,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_1) {
 nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1369,7 +1400,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_2) {
 nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1410,7 +1441,7 @@ TEST_F(CudaBasicsTests1, execPairwiseTransform_1) {
 nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1444,7 +1475,7 @@ TEST_F(CudaBasicsTests1, execPairwiseBoolTransform_1) {
 nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1477,7 +1508,7 @@ TEST_F(CudaBasicsTests1, execTransformFloat_1) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1507,7 +1538,7 @@ TEST_F(CudaBasicsTests1, execTransformFloat_2) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1524,7 +1555,6 @@ TEST_F(CudaBasicsTests1, execTransformAny_1) {
 NDArray z('c', {4,1}, {100,100,100,100}, nd4j::DataType::INT32);
 NDArray exp('c', {4,1}, {0, 2, 6, 12}, nd4j::DataType::INT32);
 x.permutei({1,0});
-x.syncShape();

 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -1539,7 +1569,7 @@ TEST_F(CudaBasicsTests1, execTransformAny_1) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1569,7 +1599,7 @@ TEST_F(CudaBasicsTests1, execTransformAny_2) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1586,7 +1616,6 @@ TEST_F(CudaBasicsTests1, execTransformStrict_1) {
 NDArray z('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::DOUBLE);
 NDArray exp('c', {3,2}, {0, 3, 12, 27, 48, 75}, nd4j::DataType::DOUBLE);
 x.permutei({1,0});
-x.syncShape();

 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -1601,7 +1630,7 @@ TEST_F(CudaBasicsTests1, execTransformStrict_1) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1631,7 +1660,7 @@ TEST_F(CudaBasicsTests1, execTransformStrict_2) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1648,7 +1677,6 @@ TEST_F(CudaBasicsTests1, execTransformSame_1) {
 NDArray z('c', {1,6}, {100,100,100,100,100,100}, nd4j::DataType::DOUBLE);
 NDArray exp('c', {1,6}, {0,2.25,6.25,12.25,20.25,30.25}, nd4j::DataType::DOUBLE);
 x.permutei({1,0});
-x.syncShape();

 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -1663,7 +1691,7 @@ TEST_F(CudaBasicsTests1, execTransformSame_1) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1693,7 +1721,7 @@ TEST_F(CudaBasicsTests1, execTransformSame_2) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1710,7 +1738,6 @@ TEST_F(CudaBasicsTests1, execTransformBool_1) {
 NDArray z('c', {1,6}, {100,100,100,100,100,100}, nd4j::DataType::BOOL);
 NDArray exp('c', {1,6}, {0,0,1,0,1,0}, nd4j::DataType::BOOL);
 x.permutei({1,0});
-x.syncShape();

 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -1725,7 +1752,7 @@ TEST_F(CudaBasicsTests1, execTransformBool_1) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1755,7 +1782,7 @@ TEST_F(CudaBasicsTests1, execTransformBool_2) {
 nullptr, nullptr, nullptr);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1772,7 +1799,6 @@ TEST_F(CudaBasicsTests1, execReduceFloat_1) {
 NDArray z('c', {3}, {100,100,100}, nd4j::DataType::FLOAT32);
 NDArray exp('c', {3}, {2.5, 6.5, 10.5}, nd4j::DataType::FLOAT32);
 x.permutei({2,1,0});
-x.syncShape();

 std::vector<int> dimensions = {0,2};

@@ -1807,7 +1833,7 @@ TEST_F(CudaBasicsTests1, execReduceFloat_1) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1861,7 +1887,7 @@ TEST_F(CudaBasicsTests1, execReduceFloat_2) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1882,7 +1908,6 @@ TEST_F(CudaBasicsTests1, execReduceSame_1) {
 NDArray z('c', {3}, {100,100,100}, nd4j::DataType::INT32);
 NDArray exp('c', {3}, {20, 52, 84}, nd4j::DataType::INT32);
 x.permutei({2,1,0});
-x.syncShape();

 std::vector<int> dimensions = {0,2};

@@ -1917,7 +1942,7 @@ TEST_F(CudaBasicsTests1, execReduceSame_1) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -1971,7 +1996,7 @@ TEST_F(CudaBasicsTests1, execReduceSame_2) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);

 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();

 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@ -1992,7 +2017,7 @@ TEST_F(CudaBasicsTests1, execReduceBool_1) {
|
|||||||
NDArray z('c', {3}, {100,100,100}, nd4j::DataType::BOOL);
|
NDArray z('c', {3}, {100,100,100}, nd4j::DataType::BOOL);
|
||||||
NDArray exp('c', {3}, {0, 1, 1}, nd4j::DataType::BOOL);
|
NDArray exp('c', {3}, {0, 1, 1}, nd4j::DataType::BOOL);
|
||||||
x.permutei({2,1,0});
|
x.permutei({2,1,0});
|
||||||
x.syncShape();
|
|
||||||
|
|
||||||
std::vector<int> dimensions = {0,2};
|
std::vector<int> dimensions = {0,2};
|
||||||
|
|
||||||
@ -2027,7 +2052,7 @@ TEST_F(CudaBasicsTests1, execReduceBool_1) {
|
|||||||
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
|
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
|
||||||
|
|
||||||
cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
|
cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
|
||||||
z.syncToHost();
|
z.tickWriteDevice();
|
||||||
|
|
||||||
// verify results
|
// verify results
|
||||||
for (int e = 0; e < z.lengthOf(); e++)
|
for (int e = 0; e < z.lengthOf(); e++)
|
||||||
@ -2081,7 +2106,7 @@ TEST_F(CudaBasicsTests1, execReduceBool_2) {
|
|||||||
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
|
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
|
||||||
|
|
||||||
cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
|
cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
|
||||||
z.syncToHost();
|
z.tickWriteDevice();
|
||||||
|
|
||||||
// verify results
|
// verify results
|
||||||
for (int e = 0; e < z.lengthOf(); e++)
|
for (int e = 0; e < z.lengthOf(); e++)
|
||||||
@ -2102,7 +2127,6 @@ TEST_F(CudaBasicsTests1, execReduceLong_1) {
|
|||||||
NDArray z('c', {3}, {100,100,100}, nd4j::DataType::INT64);
|
NDArray z('c', {3}, {100,100,100}, nd4j::DataType::INT64);
|
||||||
NDArray exp('c', {3}, {5,6,6}, nd4j::DataType::INT64);
|
NDArray exp('c', {3}, {5,6,6}, nd4j::DataType::INT64);
|
||||||
x.permutei({2,1,0});
|
x.permutei({2,1,0});
|
||||||
x.syncShape();
|
|
||||||
|
|
||||||
std::vector<int> dimensions = {0,2};
|
std::vector<int> dimensions = {0,2};
|
||||||
|
|
||||||
@ -2137,7 +2161,7 @@ TEST_F(CudaBasicsTests1, execReduceLong_1) {
|
|||||||
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
|
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
|
||||||
|
|
||||||
cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
|
cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
|
||||||
z.syncToHost();
|
z.tickWriteDevice();
|
||||||
|
|
||||||
// verify results
|
// verify results
|
||||||
for (int e = 0; e < z.lengthOf(); e++)
|
for (int e = 0; e < z.lengthOf(); e++)
|
||||||
@ -2191,7 +2215,7 @@ TEST_F(CudaBasicsTests1, execReduceLong_2) {
|
|||||||
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
|
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
|
||||||
|
|
||||||
cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
|
cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
|
||||||
z.syncToHost();
|
z.tickWriteDevice();
|
||||||
|
|
||||||
// verify results
|
// verify results
|
||||||
for (int e = 0; e < z.lengthOf(); e++)
|
for (int e = 0; e < z.lengthOf(); e++)
|
||||||
@@ -2209,10 +2233,9 @@ TEST_F(CudaBasicsTests1, execReduceLong_2) {
 TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) {
 
 NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32);
-NDArray z('c', {0}, {100}, nd4j::DataType::FLOAT32);
-NDArray exp('c', {0}, {6.5}, nd4j::DataType::FLOAT32);
+NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32);
+NDArray exp('c', {}, {6.5}, nd4j::DataType::FLOAT32);
 x.permutei({2,1,0});
-x.syncShape();
 
 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -2233,7 +2256,7 @@ TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) {
 nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo());
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2247,8 +2270,8 @@ TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) {
 TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) {
 
 NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32);
-NDArray z('c', {0}, {100}, nd4j::DataType::DOUBLE);
-NDArray exp('c', {0}, {6.5}, nd4j::DataType::DOUBLE);
+NDArray z('c', {}, {100}, nd4j::DataType::DOUBLE);
+NDArray exp('c', {}, {6.5}, nd4j::DataType::DOUBLE);
 
 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -2269,7 +2292,7 @@ TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) {
 nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo());
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2283,10 +2306,9 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_1) {
 TEST_F(CudaBasicsTests1, execReduceSameScalar_1) {
 
 NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::INT32);
-NDArray z('c', {0}, {100}, nd4j::DataType::INT32);
-NDArray exp('c', {0}, {156}, nd4j::DataType::INT32);
+NDArray z('c', {}, {100}, nd4j::DataType::INT32);
+NDArray exp('c', {}, {156}, nd4j::DataType::INT32);
 x.permutei({2,1,0});
-x.syncShape();
 
 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -2307,7 +2329,7 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_1) {
 nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo());
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2321,8 +2343,8 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_2) {
 TEST_F(CudaBasicsTests1, execReduceSameScalar_2) {
 
 NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, nd4j::DataType::DOUBLE);
-NDArray z('c', {0}, {100}, nd4j::DataType::DOUBLE);
-NDArray exp('c', {0}, {156}, nd4j::DataType::DOUBLE);
+NDArray z('c', {}, {100}, nd4j::DataType::DOUBLE);
+NDArray exp('c', {}, {156}, nd4j::DataType::DOUBLE);
 
 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -2343,7 +2365,7 @@ TEST_F(CudaBasicsTests1, execReduceSameScalar_2) {
 nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo());
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2357,8 +2379,8 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) {
 TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) {
 
 NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::INT32);
-NDArray z('c', {0}, {100}, nd4j::DataType::BOOL);
-NDArray exp('c', {0}, {1}, nd4j::DataType::BOOL);
+NDArray z('c', {}, {100}, nd4j::DataType::BOOL);
+NDArray exp('c', {}, {1}, nd4j::DataType::BOOL);
 x.permutei({2,1,0});
 x.syncShape();
 
@@ -2381,7 +2403,7 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) {
 nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo());
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2395,8 +2417,8 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) {
 TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) {
 
 NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, nd4j::DataType::DOUBLE);
-NDArray z('c', {0}, {100}, nd4j::DataType::BOOL);
-NDArray exp('c', {0}, {1}, nd4j::DataType::BOOL);
+NDArray z('c', {}, {100}, nd4j::DataType::BOOL);
+NDArray exp('c', {}, {1}, nd4j::DataType::BOOL);
 
 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -2417,7 +2439,7 @@ TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) {
 nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo());
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2431,8 +2453,8 @@ TEST_F(CudaBasicsTests1, execReduceLongScalar_1) {
 TEST_F(CudaBasicsTests1, execReduceLongScalar_1) {
 
 NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::INT32);
-NDArray z('c', {0}, {100}, nd4j::DataType::INT64);
-NDArray exp('c', {0}, {17}, nd4j::DataType::INT64);
+NDArray z('c', {}, {100}, nd4j::DataType::INT64);
+NDArray exp('c', {}, {17}, nd4j::DataType::INT64);
 x.permutei({2,1,0});
 x.syncShape();
 
@@ -2455,7 +2477,7 @@ TEST_F(CudaBasicsTests1, execReduceLongScalar_1) {
 nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo());
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2469,8 +2491,8 @@ TEST_F(CudaBasicsTests1, execReduceLongScalar_2) {
 TEST_F(CudaBasicsTests1, execReduceLongScalar_2) {
 
 NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, nd4j::DataType::DOUBLE);
-NDArray z('c', {0}, {100}, nd4j::DataType::INT64);
-NDArray exp('c', {0}, {17}, nd4j::DataType::INT64);
+NDArray z('c', {}, {100}, nd4j::DataType::INT64);
+NDArray exp('c', {}, {17}, nd4j::DataType::INT64);
 
 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -2491,7 +2513,7 @@ TEST_F(CudaBasicsTests1, execReduceLongScalar_2) {
 nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo());
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2534,6 +2556,9 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_1) {
 
 cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult);
 
+x.syncToDevice();
+y.syncToDevice();
 
 // call cuda kernel which calculates result
 NativeOpExecutioner::execReduce3TAD(&lc, nd4j::reduce3::Dot,
 nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
@@ -2544,7 +2569,7 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_1) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2600,7 +2625,7 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_2) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2656,7 +2681,7 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_3) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2674,8 +2699,8 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_4) {
 
 NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
 NDArray y('c', {2,2,3}, {10,20,30,40,50,60,70,80,90,100,110,120}, nd4j::DataType::DOUBLE);
-NDArray exp('c', {0}, {1820}, nd4j::DataType::FLOAT32);
-NDArray z('c', {0}, {100}, nd4j::DataType::FLOAT32);
+NDArray exp('c', {}, {1820}, nd4j::DataType::FLOAT32);
+NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32);
 
 std::vector<int> dimensions = {0,1,2};
 
@@ -2711,7 +2736,7 @@ TEST_F(CudaBasicsTests1, execReduce3TAD_4) {
 (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2728,8 +2753,8 @@ TEST_F(CudaBasicsTests1, execSummaryStats_1) {
 TEST_F(CudaBasicsTests1, execSummaryStats_1) {
 
 NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64);
-NDArray exp('c', {0}, {3.605551}, nd4j::DataType::FLOAT32);
-NDArray z('c', {0}, {100}, nd4j::DataType::FLOAT32);
+NDArray exp('c', {}, {3.605551}, nd4j::DataType::FLOAT32);
+NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32);
 
 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -2748,7 +2773,7 @@ TEST_F(CudaBasicsTests1, execSummaryStats_1) {
 true);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2799,7 +2824,7 @@ TEST_F(CudaBasicsTests1, execSummaryStats_2) {
 true);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2853,7 +2878,7 @@ TEST_F(CudaBasicsTests1, execSummaryStats_3) {
 true);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2870,8 +2895,8 @@ TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) {
 TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) {
 
 NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, nd4j::DataType::INT64);
-NDArray exp('c', {0}, {3.605551}, nd4j::DataType::FLOAT32);
-NDArray z('c', {0}, {100}, nd4j::DataType::FLOAT32);
+NDArray exp('c', {}, {3.605551}, nd4j::DataType::FLOAT32);
+NDArray z('c', {}, {100}, nd4j::DataType::FLOAT32);
 
 // create cuda stream and LaunchContext
 cudaError_t cudaResult;
@@ -2890,7 +2915,7 @@ TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) {
 true);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2932,7 +2957,7 @@ TEST_F(CudaBasicsTests1, execRandom_1) {
 devicePtrs[0]);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -2977,7 +3002,7 @@ TEST_F(CudaBasicsTests1, execRandom_2) {
 devicePtrs[0]);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -3020,7 +3045,7 @@ TEST_F(CudaBasicsTests1, execRandom_3) {
 devicePtrs[0]);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -3064,7 +3089,7 @@ TEST_F(CudaBasicsTests1, execRandom_4) {
 devicePtrs[0]);
 
 cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult);
-z.syncToHost();
+z.tickWriteDevice();
 
 // verify results
 for (int e = 0; e < z.lengthOf(); e++)
@@ -406,7 +406,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_18) {
 
 nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.);
 
-ASSERT_TRUE(c.equalsTo(&exp));
+ASSERT_TRUE(c.equalsTo(&exp, 1e-1));
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -428,7 +428,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_19) {
 
 nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.);
 
-ASSERT_TRUE(c.equalsTo(&exp));
+ASSERT_TRUE(c.equalsTo(&exp, 1e-1));
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -450,7 +450,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_20) {
 
 nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.);
 
-ASSERT_TRUE(c.equalsTo(&exp));
+ASSERT_TRUE(c.equalsTo(&exp, 1e-1));
 }
 
 //////////////////////////////////////////////////////////////////////////
@@ -467,7 +467,6 @@ TEST_F(CudaBasicsTests2, mmulMxM_21) {
 NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, nd4j::DataType::DOUBLE);
 
 nd4j::MmulHelper::mmul(&a, &b, &c, 1., 0.);
-// c.printBuffer();
 
 ASSERT_TRUE(c.equalsTo(&exp));
 }
@@ -552,6 +551,7 @@ TEST_F(CudaBasicsTests2, mmulMxM_26) {
 const Nd4jLong K = 4;
 const Nd4jLong N = 5;
 
+// 3x4 * 4x5 = 3x5
 NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, nd4j::DataType::INT64);
 NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, nd4j::DataType::FLOAT32);
 NDArray c('c', {M,N}, nd4j::DataType::DOUBLE);
@@ -1097,7 +1097,7 @@ TEST_F(CudaBasicsTests2, mmulDot_1) {
 NDArray y('f', {N}, {0.1, 0.2, 0.3, 0.4}, nd4j::DataType::FLOAT32);
 NDArray z(nd4j::DataType::DOUBLE);
 
-NDArray exp('c', {0}, {3}, nd4j::DataType::DOUBLE);
+NDArray exp('c', {}, {3}, nd4j::DataType::DOUBLE);
 
 nd4j::MmulHelper::mmul(&x, &y, &z);
 ASSERT_TRUE(z.equalsTo(&exp));
@@ -1112,7 +1112,7 @@ TEST_F(CudaBasicsTests2, mmulDot_2) {
 NDArray y('f', {1,1,N,1,1,1}, {0.1, 0.2, 0.3, 0.4}, nd4j::DataType::FLOAT32);
 NDArray z(nd4j::DataType::DOUBLE);
 
-NDArray exp('c', {0}, {3}, nd4j::DataType::DOUBLE);
+NDArray exp('c', {}, {3}, nd4j::DataType::DOUBLE);
 
 nd4j::MmulHelper::mmul(&x, &y, &z);
 ASSERT_TRUE(z.equalsTo(&exp));
@@ -1129,7 +1129,7 @@ TEST_F(CudaBasicsTests2, mmulDot_3) {
 NDArray y = yBig(0, {1}, true);
 NDArray z(nd4j::DataType::DOUBLE);
 
-NDArray exp('c', {0}, {3}, nd4j::DataType::DOUBLE);
+NDArray exp('c', {}, {3}, nd4j::DataType::DOUBLE);
 
 nd4j::MmulHelper::mmul(&x, &y, &z);
 ASSERT_TRUE(z.equalsTo(&exp));
@@ -1146,7 +1146,7 @@ TEST_F(CudaBasicsTests2, mmulDot_4) {
 NDArray y = yBig(0, {1});
 NDArray z(nd4j::DataType::DOUBLE);
 
-NDArray exp('c', {0}, {3}, nd4j::DataType::DOUBLE);
+NDArray exp('c', {}, {3}, nd4j::DataType::DOUBLE);
 
 nd4j::MmulHelper::mmul(&x, &y, &z);
 ASSERT_TRUE(z.equalsTo(&exp));
||||||
|
@@ -154,6 +154,78 @@ TEST_F(DeclarableOpsTests1, BasicInitialization2) {
     ASSERT_EQ(1, op->getOpDescriptor()->getNumberOfOutputs());
 }

+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) {
+    auto x = NDArrayFactory::create<double>('c', {3,4}, {1,2,3,4,5,6,7,8,9,10,11,12});
+    auto y = NDArrayFactory::create<double>('c', {3,4}, {0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,1.1,1.2});
+    auto exp = NDArrayFactory::create<double>('c', {3,4});
+    exp.linspace(0.9, 0.9);
+    nd4j::ops::apply_sgd op;
+    auto result = op.execute({&x, &y}, {1.}, {}, {}, false, nd4j::DataType::DOUBLE);
+    ASSERT_EQ(result->status(), ND4J_STATUS_OK);
+    auto z = result->at(0);
+//    result->at(0)->printIndexedBuffer("OUTPUT");
+//    result->at(0)->printShapeInfo("OUTPUT Shape");
+//    exp.printIndexedBuffer("EXPECT");
+    ASSERT_TRUE(z->equalsTo(exp));
+    delete result;
+}
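The expected output of the ApplyGradientDescent test follows from the plain SGD rule; a minimal sketch (plain C++, not the libnd4j kernel) of the arithmetic the test encodes:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Plain SGD update, w' = w - lr * g, applied elementwise. With lr = 1,
// w = 1..12 and g = 0.1..1.2 (the inputs in the test above), the result
// is 0.9, 1.8, ..., 10.8, which is exactly exp.linspace(0.9, 0.9).
std::vector<double> sgdStep(const std::vector<double>& w,
                            const std::vector<double>& g, double lr) {
    std::vector<double> out(w.size());
    for (size_t i = 0; i < w.size(); ++i)
        out[i] = w[i] - lr * g[i];
    return out;
}
```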
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) {
+    auto x = NDArrayFactory::create<double>('c', {3,4}, {1,2,3,4,5,6,7,8,9,10,11,12});
+    auto y = NDArrayFactory::create<double>('c', {1,4}, {0.1,0.2,0.3,0.4});
+    auto exp = NDArrayFactory::create<double>('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4});
+    nd4j::ops::assign op;
+    auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    ASSERT_EQ(result->status(), ND4J_STATUS_OK);
+    auto z = result->at(0);
+//    result->at(0)->printIndexedBuffer("OUTPUT");
+//    result->at(0)->printShapeInfo("OUTPUT Shape");
+//    exp.printIndexedBuffer("EXPECT");
+    ASSERT_TRUE(z->equalsTo(exp));
+    delete result;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) {
+    auto x = NDArrayFactory::create<double>('c', {3,4}, {1,2,3,4,5,6,7,8,9,10,11,12});
+    auto y = NDArrayFactory::create<double>('c', {1,4}, {0.1,0.2,0.3,0.4});
+    auto eps = NDArrayFactory::create<double>('c', {3,4}, {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4});
+    auto exp1 = NDArrayFactory::create<double>('c', {3,4}); // zero
+    auto exp2 = NDArrayFactory::create<double>('c', {1,4}, {3, 6, 9, 12});
+    nd4j::ops::assign_bp op;
+    auto result = op.execute({&x, &y, &eps}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
+    ASSERT_EQ(result->status(), ND4J_STATUS_OK);
+    auto z1 = result->at(0);
+    auto z2 = result->at(1);
+//    z1->printIndexedBuffer("OUTPUT");
+//    z2->printIndexedBuffer("OUTPUT");
+//
+//    exp1.printIndexedBuffer("EXPECT");
+//    exp2.printIndexedBuffer("EXPECT");
+
+    ASSERT_TRUE(z1->equalsTo(exp1));
+    ASSERT_TRUE(z2->equalsTo(exp2));
+    delete result;
+}
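The assign_bp expectations can be checked by hand: the gradient w.r.t. x is zero, and the gradient w.r.t. the broadcast {1,4} input is eps summed over the broadcast rows. A small sketch of that reduction (plain C++, not the op itself):

```cpp
#include <cassert>
#include <vector>

// Reducing a {rows, cols} eps over the broadcast (row) dimension gives the
// gradient for a {1, cols} input. For the eps in the test above, the column
// sums are {3, 6, 9, 12}, matching exp2.
std::vector<double> columnSums(const std::vector<double>& eps, int rows, int cols) {
    std::vector<double> out(cols, 0.0);
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            out[c] += eps[r * cols + c];
    return out;
}
```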
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, AXpY_Test_1) {
+    auto x = NDArrayFactory::create<double>('c', {3,4}, {1,2,3,4,5,6,7,8,9,10,11,12});
+    auto y = NDArrayFactory::create<double>('c', {3,4}, {1,2,3,4,5,6,7,8,9,10,11,12});
+    auto exp = NDArrayFactory::create<double>('c', {3,4});
+    exp.linspace(3, 3);
+    nd4j::ops::axpy op;
+    auto result = op.execute({&x, &y}, {2.}, {}, {}, false, nd4j::DataType::DOUBLE);
+    ASSERT_EQ(result->status(), ND4J_STATUS_OK);
+    auto z = result->at(0);
+//    result->at(0)->printIndexedBuffer("OUTPUT");
+//    result->at(0)->printShapeInfo("OUTPUT Shape");
+//    exp.printIndexedBuffer("EXPECT");
+    ASSERT_TRUE(z->equalsTo(exp));
+    delete result;
+}
+
 TEST_F(DeclarableOpsTests1, BasicInitialization3) {
     auto op1 = nd4j::ops::OpRegistrator::getInstance()->getOperation("concat");
@@ -1995,6 +1995,30 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
     delete results;
 }

+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) {
+
+    NDArray images('c', {1, 100, 100, 3});
+    NDArray boxes('c', {1,4}, {0,0,1,1}, nd4j::DataType::FLOAT32);
+    NDArray boxI('c', {2}, {1,1}, nd4j::DataType::INT32);
+    NDArray cropSize = NDArrayFactory::create<int>({10, 10});
+
+    //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
+    NDArray expected('c', {1, 10, 10,3}, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::crop_and_resize op;
+    auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto result = results->at(0);
+    result->printShapeInfo("Cropped and Resized");
+    ASSERT_TRUE(expected.isSameShapeStrict(result));
+    //ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
@@ -879,7 +879,7 @@ TEST_F(DeclarableOpsTests12, pullRows_1) {
     Nd4jPointer nativeStart[2];

 #ifdef __CUDABLAS__
-    nativeStart[1] = *(x.getContext()->getCudaStream());
+    nativeStart[1] = (x.getContext()->getCudaStream());
 #endif

     pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
@@ -913,7 +913,7 @@ TEST_F(DeclarableOpsTests12, pullRows_2) {

     Nd4jPointer nativeStart[2];
 #ifdef __CUDABLAS__
-    nativeStart[1] = *(x.getContext()->getCudaStream());
+    nativeStart[1] = (x.getContext()->getCudaStream());
 #endif
     pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
              z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
@@ -161,11 +161,12 @@ TEST_F(DeclarableOpsTests13, test_eval_reduction_shape_1) {
 }

 TEST_F(DeclarableOpsTests13, test_or_1) {
-    auto x = NDArrayFactory::create<bool>('c', {4}, {false, true, false, true});
-    auto y = NDArrayFactory::create<bool>('c', {4}, {false, false, true, true});
-    auto e = NDArrayFactory::create<bool>('c', {4}, {false, true, true, true});
-
-    auto z = NDArrayFactory::create<bool>('c', {4});
+    NDArray x('c', {4}, {false, true, false, true}, nd4j::DataType::BOOL);
+    NDArray y('c', {4}, {false, false, true, true}, nd4j::DataType::BOOL);
+    NDArray e('c', {4}, {false, true, true, true}, nd4j::DataType::BOOL);
+
+    NDArray z('c', {4}, nd4j::DataType::BOOL);

     x.applyPairwiseTransform(pairwise::Or, &y, &z, nullptr);

@@ -292,7 +293,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
     auto cols = NDArrayFactory::create<int>({4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1});
     auto vals = NDArrayFactory::create<double>({0.6199614579042966, 0.19644097697184246, 0.13824979367331638, 0.01949900138247239, 0.008923198738222747, 0.008392793826291798, 0.0033348224714784204, 0.0026246189757042166, 0.0025733360563748838, 0.5877136110798608, 0.28250257562439585, 0.08098135424273815, 0.014862718272075049, 0.01219187321450782, 0.01152346362368888, 0.004243137936786281, 0.0034626999030188577, 0.0025185661029283168, 0.6777005651521399, 0.18321248222489303, 0.04018202465629351, 0.02941935889988646, 0.02164146250842832, 0.019898422145651618, 0.011683461395713935, 0.008439076090480863, 0.007823146926512332, 0.6770900431883232, 0.16617511239723026, 0.06039349887686468, 0.04650913399744179, 0.016886531410284355, 0.014591049666869658, 0.006407638669806174, 0.006074413005122801, 0.0058725787880570205, 0.6278185083409108, 0.235127797795446, 0.07023700015217448, 0.030885483448633774, 0.01229522088606573, 0.009238279699136107, 0.008219511168822047, 0.004303744819835723, 0.0018744536889749907, 0.7122603898978483, 0.07862620103245824, 0.07061257369349086, 0.06721483653169834, 0.028957853952131768, 0.01778978123182596, 0.01481713955181034, 0.005492728917348627, 0.0042284951913875955, 0.5266844101016999, 0.3304104787383107, 0.10930017433210941, 0.018514917515240075, 0.006969360999637938, 0.0063776901975396, 0.0010590388116165708, 6.526830884629785E-4, 3.1246215383067865E-5, 0.7176179284835663, 0.08741734015883978, 0.05927699083866909, 0.04663169573956976, 0.03287576269194147, 0.02993912340339554, 0.013365238657916641, 0.010616858763291145, 0.002259061262810172, 0.6891905160321706, 0.1397658294110526, 0.05438284759722162, 0.05437184733708826, 0.028683289714498808, 0.020986120697576355, 0.007218358114741088, 0.0032834770669826364, 0.002117714028667893, 0.6823873496503976, 0.1345267083671607, 0.08712863515505885, 0.04286621088946242, 0.02544804597749639, 0.01689343932533317, 0.007219134659004873, 0.0019232929717404616, 0.0016071830043453991, 0.6425809622897437, 0.18474464886441516, 0.10897036475298316, 0.03466939253836615, 0.013288054277817787, 0.005149178177380355, 0.0037974063158903518, 0.0037851733015991287, 0.0030148194818042273});
     //auto buf = NDArrayFactory::create<double>('c', {4});
-    auto exp = NDArrayFactory::create<double>('c', {11, 5}, {-0.08182565030695285, -0.10231113628399446, 0.016815534365147027, 0.16174900250174604, -0.2069849599383414, -0.12637623042629828, 0.10991761828249218, 0.13982379012581797, -0.09160092204813117, 0.09219561399020912, 0.14251517534534905, 0.014713798033084492, 0.1978861666999472, -0.25244458878217496, -0.0183980318957791, 0.13649108861674678, 0.07642892434591711, -0.07614804543349199, 0.12922677007082004, -0.19230554501535452, -0.1125183370752973, -0.0959552766032053, 0.014909543143622344, 0.018856765542142554, 0.19992319593641855, 0.3024116381613982, -0.18827088592810545, 0.10219412073880345, -0.09701789574674309, -0.00327343943101904, 0.15208047075206382, -0.024040804138184012, -0.13907523297518, 0.30082909806368757, 0.17454904535785062, -0.315679421513792, 0.1422789919281073, -0.08984704554824749, 0.011189110407313801, -0.1073173666189425, -0.24925394718449986, 0.10762857570027974, 0.034332424313159, 3.347586409494324E-4, -0.17491784809038768, 0.0711742983812613, 0.15171952090213991, -0.0888982509512525, -0.20577777883552498, 0.02762112109359763, 0.08098091525002123, -0.19210932623853155, -0.11199300447832489, 0.02465568018591474, 0.20890821461174836});
+    auto exp = NDArrayFactory::create<double>('c', {11, 5}, {-0.080205, -0.085862, 0.024045, 0.133551, -0.199896, -0.170597, 0.187301, 0.205824, -0.165268, 0.131228, 0.155135, 0.021446, 0.217583, -0.262873, -0.021075, 0.114537, 0.088023, -0.039205, 0.087984, -0.179565, -0.132683, 0.003677, 0.072081, -0.068737, 0.204481, 0.287223, -0.193989, 0.104569, -0.123401, -0.036368, 0.086745, 0.002961, -0.091327, 0.234853, 0.120270, -0.304006, 0.128305, -0.084867, -0.017550, -0.130837, -0.288569, 0.124679, 0.054078, -0.034187, -0.192599, 0.033196, 0.228182, -0.044972, -0.314217, 0.020287, 0.054427, -0.078887, -0.078246, -0.104543, 0.169803});
     //auto exp2 = NDArrayFactory::create<double>({-4., -4., -4., -4.
     //std::vector<NDArray*> exp({&exp1, &exp2});
     //data.assign(1.0); //linspace(1);
@@ -308,7 +309,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) {
     result->at(0)->printBuffer("Output");
     exp.printBuffer("Expect");
     //result->at(0)->printShapeInfo("Shape output");
-    //ASSERT_TRUE(exp.equalsTo(result->at(0)));
+    ASSERT_TRUE(exp.equalsTo(result->at(0)));
     delete result;
 }

@@ -265,8 +265,7 @@ TEST_F(DeclarableOpsTests14, test_empty_stack_1) {
     auto res2 = sumOp.execute({&e}, {1.}, {1});
     ASSERT_EQ(res2->status(), Status::OK());
     auto out = res2->at(0);
-    out->printShapeInfo("ReduceSum empty shape with keep dims");
-    out->printIndexedBuffer("ReduceSum scalar");
     ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
     delete res2;
     delete result;
@@ -72,9 +72,9 @@ TEST_F(DeclarableOpsTests15, Test_Half_assign_1) {
 }

 TEST_F(DeclarableOpsTests15, test_avgpooling_edge_1) {
-    int inOutH = 35;
-    int inOutW = 35;
-    int inOutC = 192;
+    int inOutH = 5;// 35;
+    int inOutW = 5;// 35;
+    int inOutC = 10;// 192;

     auto x = NDArrayFactory::create<double>('c', {1, inOutH, inOutW, inOutC});
     x.linspace(1.0);
@@ -273,10 +273,12 @@ TEST_F(DeclarableOpsTests15, test_hashCode_1) {
     y.linspace(2.);

     nd4j::ops::hashcode op;
-    auto resultA0 = op.execute({&x}, {}, {});
-    auto resultA1 = op.execute({&x}, {}, {});
-    auto resultB0 = op.execute({&y}, {}, {});
+    auto resultA0 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64);
+    auto resultA1 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64);
+    auto resultB0 = op.execute({&y}, {}, {}, {}, false, nd4j::DataType::INT64);
+//    resultA0->at(0)->printIndexedBuffer("A0");
+//    resultA1->at(0)->printIndexedBuffer("A1");
+//    resultB0->at(0)->printIndexedBuffer("B0");
     ASSERT_EQ(*resultA0->at(0), *resultA1->at(0));
     ASSERT_NE(*resultA0->at(0), *resultB0->at(0));

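The hashCode tests assert only two properties: the hash is deterministic (same input, same INT64 hash) and input-sensitive (different input, different hash). The real algorithm lives in the hashcode CUDA/CPU helper; this FNV-1a sketch is purely an illustrative assumption, not the kernel's actual hash:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// FNV-1a over the raw bytes of a double buffer. This is NOT the libnd4j
// hashcode algorithm (an assumption for illustration only); it just shows
// the determinism/sensitivity contract the tests above check.
uint64_t fnv1a(const std::vector<double>& data) {
    uint64_t h = 1469598103934665603ULL;        // FNV offset basis
    for (double d : data) {
        unsigned char bytes[sizeof(double)];
        std::memcpy(bytes, &d, sizeof(double));
        for (unsigned char b : bytes) {
            h ^= b;                             // fold in one byte
            h *= 1099511628211ULL;              // FNV prime
        }
    }
    return h;
}
```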
@@ -293,9 +295,13 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) {
     y.linspace(2.);

     nd4j::ops::hashcode op;
-    auto resultA0 = op.execute({&x}, {}, {});
-    auto resultA1 = op.execute({&x}, {}, {});
-    auto resultB0 = op.execute({&y}, {}, {});
+    auto resultA0 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64);
+    auto resultA1 = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT64);
+    auto resultB0 = op.execute({&y}, {}, {}, {}, false, nd4j::DataType::INT64);
+
+//    resultA0->at(0)->printIndexedBuffer("A0");
+//    resultA1->at(0)->printIndexedBuffer("A1");
+//    resultB0->at(0)->printIndexedBuffer("B0");
+
     ASSERT_EQ(*resultA0->at(0), *resultA1->at(0));
     ASSERT_NE(*resultA0->at(0), *resultB0->at(0));
@@ -374,25 +380,25 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {

     auto z = result->at(0);

-    z->printIndexedBuffer("Z");
+    // z->printIndexedBuffer("Z");

     delete result;
 }

 TEST_F(DeclarableOpsTests15, test_lstmBlock_2) {
-    int seqLength = 32;
-    int m = 64;
-    int n = 32;
+    int seqLen = 32;
+    int bS = 64;
+    int nIn = 32;

     auto x0 = NDArrayFactory::create<Nd4jLong>(5);
-    auto x1 = NDArrayFactory::create<float>('f', {m, n, seqLength});
-    auto x2 = NDArrayFactory::create<float>('f', {m, n});
-    auto x3 = NDArrayFactory::create<float>('f', {m, n});
-    auto x4 = NDArrayFactory::create<float>('f', {2 * n, 4 * n});
-    auto x5 = NDArrayFactory::create<float>('f', {n});
-    auto x6 = NDArrayFactory::create<float>('f', {n});
-    auto x7 = NDArrayFactory::create<float>('f', {n});
-    auto x8 = NDArrayFactory::create<float>('f', {4 * n});
+    auto x1 = NDArrayFactory::create<float>('f', {bS, nIn, seqLen});
+    auto x2 = NDArrayFactory::create<float>('f', {bS, nIn}); // nIn == nOut
+    auto x3 = NDArrayFactory::create<float>('f', {bS, nIn});
+    auto x4 = NDArrayFactory::create<float>('f', {2 * nIn, 4 * nIn});
+    auto x5 = NDArrayFactory::create<float>('f', {nIn});
+    auto x6 = NDArrayFactory::create<float>('f', {nIn});
+    auto x7 = NDArrayFactory::create<float>('f', {nIn});
+    auto x8 = NDArrayFactory::create<float>('f', {4 * nIn});

     nd4j::ops::lstmBlock op;
     auto result = op.execute({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1});
@@ -402,3 +408,29 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_2) {

     delete result;
 }
+
+TEST_F(DeclarableOpsTests15, test_lstmBlock_3) {
+
+    int seqLen = 3;
+    int bS = 2;
+    int nIn = 4;
+
+    NDArray f('f', {bS, nIn, seqLen}, nd4j::DataType::FLOAT32);
+    NDArray cLast('f', {bS, nIn}, nd4j::DataType::FLOAT32);
+
+    f = 2;
+    cLast = 3;
+
+    for (int t = 0; t < seqLen; ++t) {
+
+        //section 1
+        //auto ft = f({0,0, 0,0, t,t+1});
+        //auto temp = ft * cLast;
+
+        // section 2
+        auto ft = f({0,0, 0,0, t,t+1});
+        auto temp1 = ft.reshape('f', {bS, nIn});
+        auto temp2 = temp1 * cLast;
+    }
+}
@@ -306,6 +306,73 @@ TEST_F(DeclarableOpsTests2, gather_13) {
     delete result;
 }

+TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) {
+
+    NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}, nd4j::DataType::INT32);
+    NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, nd4j::DataType::INT32);
+
+    nd4j::ops::broadcastgradientargs op;
+
+    auto result = op.execute({&input, &indices}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_KERNEL_FAILURE, result->status());
+
+    delete result;
+}
+
+TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) {
+    auto exp0 = NDArrayFactory::create<double>('c', {1, 10});
+    auto exp1 = NDArrayFactory::create<double>('c', {1, 10});
+    auto exp2 = NDArrayFactory::create<double>('c', {1, 10});
+
+    exp0.assign(0.0095);
+    exp1.assign(0.019875);
+    exp2.assign(0.02);
+
+    auto target = NDArrayFactory::create<int>(0);
+    auto ngStarter = NDArrayFactory::empty<int>();
+    auto context = NDArrayFactory::create<int>('c', {3}, {0, 1, 2});
+    auto locked = NDArrayFactory::create<int>('c', {3});
+    auto indices = NDArrayFactory::create<int>('c', {2}, {4, 5});
+    auto codes = NDArrayFactory::create<int8_t>('c', {2}, {1, 1});
+    auto syn0 = NDArrayFactory::create<double>('c', {100, 10});
+    auto syn1 = NDArrayFactory::create<double>('c', {100, 10});
+    auto syn1Neg = NDArrayFactory::empty<double>();
+    auto expTable = NDArrayFactory::create<double>('c', {10000});
+    auto negTable = NDArrayFactory::empty<double>();
+    auto numWords = NDArrayFactory::create<int>('c', {1}, {1});
+
+    syn0.assign(0.01);
+    syn1.assign(0.02);
+    expTable.assign(0.5);
+
+    auto alpha = NDArrayFactory::create<double>(0.025);
+    auto randomValue = NDArrayFactory::create<Nd4jLong>(2L);
+    auto inferenceVector = NDArrayFactory::empty<double>();
+
+    nd4j::ops::cbow op;
+    auto result = op.execute({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, true);
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto row_s0_0 = syn0({0,1, 0,0}, true);
+    auto row_s0_1 = syn0({1,2, 0,0}, true);
+    auto row_s0_2 = syn0({2,3, 0,0}, true);
+
+    auto row_s1_4 = syn1({4,5, 0,0}, true);
+    auto row_s1_5 = syn1({5,6, 0,0}, true);
+    auto row_s1_6 = syn1({6,7, 0,0}, true);
+
+    ASSERT_EQ(exp0, row_s0_0);
+    ASSERT_EQ(exp0, row_s0_1);
+    ASSERT_EQ(exp0, row_s0_2);
+
+    ASSERT_EQ(exp1, row_s1_4);
+    ASSERT_EQ(exp1, row_s1_5);
+    ASSERT_EQ(exp2, row_s1_6);
+
+    delete result;
+}
+
 TEST_F(DeclarableOpsTests2, Test_Concat_3D_1) {
     auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
     auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
@@ -412,6 +479,48 @@ TEST_F(DeclarableOpsTests2, Test_FloorMod_1) {
     delete result;
 }

+TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) {
+    auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0, 6.0, -3.0});
+    auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0, 2.0, -2.0});
+    auto exp = NDArrayFactory::create<float>('c', {1, 3}, {-2., 3., 1.,});
+
+    nd4j::ops::floordiv op;
+
+    auto result = op.execute({&x, &y}, {}, {});
+
+    auto z = result->at(0);
+//    z->printShapeInfo("FloorDiv1 shape");
+//    z->printIndexedBuffer("FloorDiv1");
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+
+    delete result;
+}
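The FloorDiv expectations follow from floor (not truncating) division; a minimal sketch of the semantics (plain C++, not the op's kernel):

```cpp
#include <cassert>
#include <cmath>

// floordiv rounds the quotient toward negative infinity, unlike C++'s
// truncating integer '/'. That is why the test above expects
// 3 / -2 -> -2, 6 / 2 -> 3 and -3 / -2 -> 1.
double floorDiv(double x, double y) {
    return std::floor(x / y);
}
```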
+
+TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) {
+    auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0, 6.0, -3.0});
+    auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0, 2.0, -2.0});
+    auto eps = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
+    auto exp1 = NDArrayFactory::create<float>('c', {1, 3}, {1, 2., 3});
+    auto exp2 = NDArrayFactory::create<float>('c', {1, 3}, {-0, -2., 3});
+
+    nd4j::ops::floordiv_bp op;
+
+    auto result = op.execute({&x, &y, &eps}, {}, {});
+    ASSERT_EQ(result->status(), Status::OK());
+    auto z1 = result->at(0);
+    auto z2 = result->at(1);
+//    z->printShapeInfo("FloorDiv1 shape");
+//    z1->printIndexedBuffer("FloorDiv2_1");
+//    z2->printIndexedBuffer("FloorDiv2_2");
+
+    ASSERT_TRUE(exp1.equalsTo(z1));
+    ASSERT_TRUE(exp2.equalsTo(z2));
+
+    delete result;
+}

 TEST_F(DeclarableOpsTests2, Test_CRelu_1) {
     auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.0, 2.0, 3.0, 4.0});
     auto exp = NDArrayFactory::create<float>('c', {2, 4}, {1.0, 2.0, 0, 0, 3.0, 4.0, 0, 0});
@@ -298,8 +298,6 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) {
     auto result = op.execute({&x}, {4.0}, {});

     auto z = result->at(0);
-    z->printIndexedBuffer("CBN1");
-    exp.printIndexedBuffer("EXP1");

     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));
@@ -315,7 +313,6 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) {
     auto result = op.execute({&x}, {6.0}, {});

     auto z = result->at(0);
-    z->printIndexedBuffer("CBN2");

     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));
@@ -323,6 +320,38 @@ TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) {
     delete result;
 }

+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) {
+
+    auto x = NDArrayFactory::create<double>('c', {3, 5});
+    auto unities = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., 1.});
+    auto scale = NDArrayFactory::create<double>('c', {3, 1}, {1.1, 1., 0.9});
+
+    x.linspace(100.);
+
+    auto xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true);
+    x /= xNorm1;
+    xNorm1 = x.reduceAlongDims(reduce::Norm2,{1}, true);
+
+    ASSERT_TRUE(unities.isSameShape(xNorm1));
+    ASSERT_TRUE(unities.equalsTo(xNorm1));
+
+    x *= scale;
+    xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true);
+
+    nd4j::ops::clipbynorm op;
+    auto result = op.execute({&x}, {1.0}, {1}, {}, false, nd4j::DataType::DOUBLE);
+    auto z = result->at(0);
+
+    auto zNorm1 = z->reduceAlongDims(reduce::Norm2, {1}, true);
+    auto exp = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., xNorm1.e<double>(2)});
+
+    ASSERT_TRUE(exp.isSameShape(&zNorm1));
+    ASSERT_TRUE(exp.equalsTo(&zNorm1));
+
+    delete result;
+}
+
 TEST_F(DeclarableOpsTests3, Test_ListDiff_1) {
     auto x= NDArrayFactory::create<float>('c', {6}, {1, 2, 3, 4, 5, 6});
     auto y= NDArrayFactory::create<float>('c', {3}, {1, 3, 5});
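The ClipByNorm_3 expectation rests on the row-wise clipping rule: rows whose L2 norm exceeds the clip value are rescaled to exactly that value, rows already under it are untouched. A small sketch of that rule (plain C++, not the clipbynorm kernel):

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Clip each row of a matrix to an L2 norm of at most clipValue, as the
// test above does along dimension 1: only rows with norm > clipValue are
// rescaled (by clipValue / norm); the rest pass through unchanged.
void clipRowsByNorm(std::vector<std::vector<double>>& m, double clipValue) {
    for (auto& row : m) {
        double norm = 0.0;
        for (double v : row) norm += v * v;
        norm = std::sqrt(norm);
        if (norm > clipValue)
            for (double& v : row) v *= clipValue / norm;
    }
}
```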
@@ -338,6 +367,9 @@ TEST_F(DeclarableOpsTests3, Test_ListDiff_1) {
     auto z0 = result->at(0);
     auto z1 = result->at(1);

+    z0->getDataBuffer()->syncToSpecial(true); // force sync
+    z1->getDataBuffer()->syncToSpecial(true); // force sync
+
     ASSERT_TRUE(exp0.isSameShape(z0));
     ASSERT_TRUE(exp0.equalsTo(z0));

@@ -2746,12 +2778,277 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
     delete results;
 }

+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests3, elu_test1) {
+
+    auto x = NDArrayFactory::create<double>('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9});
+//    auto expS = NDArrayFactory::create<double>('c', {3});
+//    auto expU = NDArrayFactory::create<double>('c', {3,3});
+    auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, -0.32968, -0.393469, -0.451188, .7, .8, .9});
+
+    nd4j::ops::elu op;
+    auto results = op.execute({&x}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto s = results->at(0);
+//    auto u = results->at(1);
+//    auto v = results->at(2);
+//    s->printIndexedBuffer("ELU");
+    ASSERT_TRUE(exp.equalsTo(s));
+
+    delete results;
+}
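The elu_test1 expectations come straight from the ELU definition; a minimal sketch of the formula (plain C++, not the op's kernel):

```cpp
#include <cassert>
#include <cmath>

// ELU with alpha = 1: f(x) = x for x > 0, else exp(x) - 1.
// exp(-0.4) - 1 ~ -0.32968 and exp(-0.5) - 1 ~ -0.393469, which match the
// negative entries of the expected array above; positive entries pass through.
double elu(double x) {
    return x > 0.0 ? x : std::exp(x) - 1.0;
}
```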
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests3, elu_test2) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9});
|
||||||
|
auto eps = NDArrayFactory::create<double>('c', {3,3});
|
||||||
|
eps.assign(2.);
|
||||||
|
// auto expU = NDArrayFactory::create<double>('c', {3,3});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 1.34064, 1.213061, 1.097623, 2, 2, 2});
|
||||||
|
|
||||||
|
nd4j::ops::elu_bp op;
|
||||||
|
auto results = op.execute({ &x, &eps }, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto s = results->at(0);
|
||||||
|
// auto u = results->at(1);
|
||||||
|
// auto v = results->at(2);
|
||||||
|
// s->printIndexedBuffer("ELU_BP");
|
||||||
|
ASSERT_TRUE(exp.equalsTo(s));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, lrelu_test1) {

    auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
    auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9});

    nd4j::ops::lrelu op;
    auto results = op.execute({&x}, {0.2}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto s = results->at(0);
    // s->printIndexedBuffer("LRELU");
    ASSERT_TRUE(exp.equalsTo(s));

    delete results;
}

TEST_F(DeclarableOpsTests3, lrelu_test2) {

    auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
    auto eps = NDArrayFactory::create<double>('c', {3,3}, {2,2,2,2,2,2,2,2,2});
    auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0, 0, 0, 2, 2, 2});

    nd4j::ops::lrelu_bp op;
    auto results = op.execute({&x, &eps}, {0.2}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto s = results->at(0);
    // s->printIndexedBuffer("LRELU_BP");
    ASSERT_TRUE(exp.equalsTo(s));

    delete results;
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, selu_test1) {

    auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
    auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309});

    nd4j::ops::selu op;
    auto results = op.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto s = results->at(0);
    // s->printIndexedBuffer("SELU");
    ASSERT_TRUE(exp.equalsTo(s));

    delete results;
}

TEST_F(DeclarableOpsTests3, selu_test2) {

    auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
    auto eps = NDArrayFactory::create<double>('c', {3,3}, {2,2,2,2,2,2,2,2,2});
    auto exp = NDArrayFactory::create<double>('c', {3,3}, {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, 2.101402, 2.101402});

    nd4j::ops::selu_bp op;
    auto results = op.execute({&x, &eps}, {0.2}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto s = results->at(0);
    // s->printIndexedBuffer("SELU_BP");
    ASSERT_TRUE(exp.equalsTo(s));

    delete results;
}

TEST_F(DeclarableOpsTests3, EQScalarTests_1) {
    Graph graph;

    auto x = NDArrayFactory::create(1.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::eq_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_TRUE(res);
}

TEST_F(DeclarableOpsTests3, EQScalarTests_2) {
    Graph graph;

    auto x = NDArrayFactory::create(2.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::eq_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_FALSE(res);
}

TEST_F(DeclarableOpsTests3, GTScalarTests_1) {
    Graph graph;

    auto x = NDArrayFactory::create(1.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::gt_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_FALSE(res);
}

TEST_F(DeclarableOpsTests3, GTScalarTests_2) {
    Graph graph;

    auto x = NDArrayFactory::create(2.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::gt_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_TRUE(res);
}

TEST_F(DeclarableOpsTests3, GTEScalarTests_1) {
    Graph graph;

    auto x = NDArrayFactory::create(1.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::gte_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_TRUE(res);
}

TEST_F(DeclarableOpsTests3, GTEScalarTests_2) {
    Graph graph;

    auto x = NDArrayFactory::create(2.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::gte_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_TRUE(res);
}

TEST_F(DeclarableOpsTests3, GTEScalarTests_3) {
    Graph graph;

    auto x = NDArrayFactory::create(1.0f);
    auto scalar = NDArrayFactory::create(2.0f);

    nd4j::ops::gte_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_FALSE(res);
}

TEST_F(DeclarableOpsTests3, LTEScalarTests_1) {
    Graph graph;

    auto x = NDArrayFactory::create(1.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::lte_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_TRUE(res);
}

TEST_F(DeclarableOpsTests3, LTEScalarTests_2) {
    Graph graph;

    auto x = NDArrayFactory::create(2.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::lte_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_FALSE(res);
}

TEST_F(DeclarableOpsTests3, LTEScalarTests_3) {
    Graph graph;

    auto x = NDArrayFactory::create(1.0f);
    auto scalar = NDArrayFactory::create(2.0f);

    nd4j::ops::lte_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_TRUE(res);
}

TEST_F(DeclarableOpsTests3, NEQScalarTests_1) {
    Graph graph;

    auto x = NDArrayFactory::create(1.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::neq_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_FALSE(res);
}

TEST_F(DeclarableOpsTests3, NEQScalarTests_2) {
    Graph graph;

    auto x = NDArrayFactory::create(2.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::neq_scalar op;
    auto res = op.evaluate({&x, &scalar});
    ASSERT_TRUE(res);
}

TEST_F(DeclarableOpsTests3, NOOPTests_1) {
    Graph graph;

    auto x = NDArrayFactory::create(2.0f);
    auto scalar = NDArrayFactory::create(1.0f);

    nd4j::ops::noop op;
    auto res = op.execute({&x, &scalar}, {}, {});
    ASSERT_TRUE(res->status() == nd4j::Status::OK());

    delete res;
}

@@ -298,6 +298,86 @@ TEST_F(DeclarableOpsTests4, Test_Fill_1) {
    delete result;
}

TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) {
    auto x = NDArrayFactory::create<double>('c', {1, 81});
    auto exp = NDArrayFactory::create<double>('c', {1, 2}, {0, 1});

    x.p(51, 1);
    x.p(52, 0);
    x.p(60, 1);
    x.p(61, 0);

    nd4j::ops::firas_sparse op;
    auto result = op.execute({&x}, {}, {0, 1});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("FIRAS");
    // z->printShapeInfo("OUTSHAPE");
    // ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) {
    auto x = NDArrayFactory::create<double>('c', {3, 3, 3, 3});
    auto exp = NDArrayFactory::create<double>('c', {81});

    x.linspace(1);
    exp.linspace(1);
    nd4j::ops::flatten op;
    auto result = op.execute({&x}, {}, {'c'});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("Flatten1");
    // z->printShapeInfo("Flatten1 shape");
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) {
    auto x = NDArrayFactory::create<double>('c', {3, 3, 3, 3});
    auto y = NDArrayFactory::create<double>('c', {3, 3});
    auto exp = NDArrayFactory::create<double>('c', {90});

    x.linspace(1);
    y.linspace(82);
    exp.linspace(1);
    nd4j::ops::flatten op;
    auto result = op.execute({&x, &y}, {}, {'c'});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer("Flatten2");
    // z->printShapeInfo("Flatten2 shape");

    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests4, Test_FloorTests_1) {
    auto x = NDArrayFactory::create<double>('c', {3, 3}, {1.5, 2.3, 3.4, 4.3, 5.9, 6.1, 7.2, 8.9, 9.7});
    auto exp = NDArrayFactory::create<double>('c', {3,3});

    exp.linspace(1);
    nd4j::ops::Floor op;
    auto result = op.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests4, Test_Reshape_Again) {
    auto x = NDArrayFactory::create<double>('c', {4, 3});
    auto exp = NDArrayFactory::create<double>('c', {4, 3});

@@ -224,49 +224,33 @@ TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) {
}

TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) {

    auto x = NDArrayFactory::create<double>('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    auto blocks = NDArrayFactory::create<double>('c', {2}, {2, 2});
    auto paddings = NDArrayFactory::create<double>('c', {2, 2}, {0, 0, 0, 0});

    auto exp = NDArrayFactory::create<double>('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &blocks, &paddings}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    //z->printIndexedBuffer("z");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_int_1) {
    auto x = NDArrayFactory::create<double>('c', {1, 2, 2, 1}, {1, 2, 3, 4});
    auto paddings = NDArrayFactory::create<int>('c', {2, 2}, {0, 0, 0, 0});

    auto exp = NDArrayFactory::create<double>('c', {4, 1, 1, 1}, {1, 2, 3, 4});

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &paddings}, {}, {2});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1_1) {
    auto x = NDArrayFactory::create<double>('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    auto exp = NDArrayFactory::create<double>('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x}, {}, {2, 2, 0, 0, 0, 0});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer();

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

@@ -274,15 +258,14 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1_1) {
    delete result;
}

TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_2) {
    auto x = NDArrayFactory::create<double>('c', {1, 2, 2, 1}, {1, 2, 3, 4});

    auto exp = NDArrayFactory::create<double>('c', {4, 1, 1, 1}, {1, 2, 3, 4});
    auto paddings = NDArrayFactory::create<int>('c', {2, 2}, {0, 0, 0, 0});

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &paddings}, {}, {2});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

@@ -295,17 +278,17 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_2) {


TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_3) {
    auto x = NDArrayFactory::create<double>('c', {2, 2, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
    auto paddings = NDArrayFactory::create<int>('c', {2, 2}, {0, 0, 2, 0});

    auto exp = NDArrayFactory::create<double>('c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11, 0, 2, 4, 0, 10, 12, 0, 5, 7, 0, 13, 15, 0, 6, 8, 0, 14, 16});

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &paddings}, {}, {2});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer();

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

@@ -313,53 +296,46 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_3) {
    delete result;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_4) {

    const int blockSize = 2;

    NDArray x('c', {3, 3*blockSize - 1 - 2, 4*blockSize - 2 - 3, 2}, {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, nd4j::DataType::FLOAT32);
    NDArray paddings = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 2, 3});

    NDArray exp('c', {3*blockSize*blockSize, 3, 4, 2}, {0,0, 0,0, 0,0, 0,0, 0,0, 11,12, 13,14, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0,
                                                        0,0, 0,0, 0,0, 35,36, 37,38, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 59,60, 61,62, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0,
                                                        0,0, 0,0, 0,0, 0,0, 83,84, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 107, 108, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0,
                                                        0,0, 0,0, 0,0, 0,0, 0,0, 131, 132, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 147, 148, 149, 150, 0,0, 0,0, 155, 156, 157, 158,
                                                        0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 171, 172, 173, 174, 0,0, 0,0, 179, 180, 181, 182, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 195, 196,
                                                        197, 198, 0,0, 0,0, 203, 204, 205, 206, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 219, 220, 0,0, 0,0, 0,0, 227, 228, 0,0, 0,0, 0,0,
                                                        0,0, 0,0, 0,0, 0,0, 243, 244, 0,0, 0,0, 0,0, 251, 252, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 267, 268, 0,0, 0,0, 0,0, 275,
                                                        276, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0}, nd4j::DataType::FLOAT32);

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &paddings}, {}, {blockSize});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer();

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_3_2) {
    auto x = NDArrayFactory::create<double>('c', {1, 2, 5, 5});
    auto blocks = NDArrayFactory::create<int>({1,1,2});
    auto paddings = NDArrayFactory::create<int>('c', {3,2}, {0,0,0,0,1,2});

    auto exp = NDArrayFactory::create<double>('c', {2, 2, 5, 4}, {0., 2., 4., 0., 0., 7., 9., 0., 0., 12., 14., 0., 0., 17., 19., 0., 0., 22., 24., 0., 0., 27., 29., 0., 0., 32., 34., 0., 0., 37., 39., 0., 0., 42., 44., 0., 0., 47., 49., 0.,
                                                                  1., 3., 5., 0., 6., 8., 10., 0., 11., 13., 15., 0., 16., 18., 20., 0., 21., 23., 25., 0., 26., 28., 30., 0., 31., 33., 35., 0., 36., 38., 40., 0., 41., 43., 45., 0., 46., 48., 50., 0.});
    x.linspace(1);
    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &blocks, &paddings}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    //z->printIndexedBuffer("Space to Batch Out");
    //z->printShapeInfo("Space to Batch Out shape");
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests5, Test_BatchToSpace_1) {
    auto x = NDArrayFactory::create<double>('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    auto exp = NDArrayFactory::create<double>('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    auto crops = NDArrayFactory::create<int>('c', {2, 2}, {0, 0, 0, 0});

    nd4j::ops::batch_to_space op;
    auto result = op.execute({&x, &crops}, {}, {2});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    // z->printIndexedBuffer();

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

@@ -367,31 +343,13 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_1) {
    delete result;
}

TEST_F(DeclarableOpsTests5, Test_BatchToSpace_1_1) {
    auto x = NDArrayFactory::create<double>('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    auto exp = NDArrayFactory::create<double>('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

    nd4j::ops::batch_to_space op;
    auto result = op.execute({&x}, {}, {2, 2, 0, 0, 0, 0});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests5, Test_BatchToSpace_2) {
    auto x = NDArrayFactory::create<double>('c', {4, 1, 1, 1}, {1, 2, 3, 4});
    auto exp = NDArrayFactory::create<double>('c', {1, 2, 2, 1}, {1, 2, 3, 4});
    auto crops = NDArrayFactory::create<int>('c', {2, 2}, {0, 0, 0, 0});

    nd4j::ops::batch_to_space op;
    auto result = op.execute({&x, &crops}, {}, {2});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

@@ -404,13 +362,15 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_2) {


TEST_F(DeclarableOpsTests5, Test_BatchToSpace_3) {
    auto x = NDArrayFactory::create<double>('c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11,
                                                                0, 2, 4, 0, 10, 12,
                                                                0, 5, 7, 0, 13, 15,
                                                                0, 6, 8, 0, 14, 16});
    auto exp = NDArrayFactory::create<double>('c', {2, 2, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
    auto crops = NDArrayFactory::create<int>('c', {2, 2}, {0, 0, 2, 0});

    nd4j::ops::batch_to_space op;
    auto result = op.execute({&x, &crops}, {}, {2});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

@@ -421,13 +381,18 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_3) {
    delete result;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Test_BatchToSpace_4) {

    const int blockSize = 2;

    NDArray x('c', {3*blockSize*blockSize, 3, 4, 2}, nd4j::DataType::FLOAT32);
    x.linspace(1, 1);
    NDArray crops = NDArrayFactory::create<int>('c', {2, 2}, {1, 2, 2, 3});

    NDArray exp('c', {3, 3*blockSize - 1 - 2, 4*blockSize - 2 - 3, 2}, {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, nd4j::DataType::FLOAT32);

    nd4j::ops::batch_to_space op;
    auto result = op.execute({&x, &crops}, {}, {blockSize});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

@@ -631,6 +596,25 @@ TEST_F(DeclarableOpsTests5, gatherNd_test6) {
    delete results;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, gatherNd_test7) {

    auto input = NDArrayFactory::create<double>('c', {4, 4});
    input.linspace(1);
    auto indices = NDArrayFactory::create<int>('c', {3,3,2}, {0,2,1, 0,1,0, 1,3,1, 0,2,1, 0,1,0, 1,3,1});
    auto expected = NDArrayFactory::create<double>('c', {3,3}, {3,5,5,8,5,10,2,2,14});

    nd4j::ops::gather_nd op;
    auto results = op.execute({&input, &indices}, {}, {});

    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test1) {

@@ -279,58 +279,6 @@ TEST_F(DeclarableOpsTests6, Test_gatherNd_Edge_1) {
    delete result;
}

TEST_F(DeclarableOpsTests6, Test_StB_1) {
    auto x = NDArrayFactory::create<double>('c', {4, 64, 64, 4});
    auto blocks = NDArrayFactory::create<double>('c', {2}, {8, 8});
    auto paddings = NDArrayFactory::create<double>('c', {2, 2}, {12, 12, 16, 16});

    x.assign(1.0f);

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &blocks, &paddings}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    //nd4j_printf("Mean: %f\n", z->meanNumber());

    delete result;
}

TEST_F(DeclarableOpsTests6, Test_StB_2) {
    auto x = NDArrayFactory::create<double>('c', {2, 6, 6, 2});
    auto blocks = NDArrayFactory::create<double>('c', {2}, {2, 2});
    auto paddings = NDArrayFactory::create<double>('c', {2, 2}, {2, 2, 2, 2});

    x.assign(1.0f);

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &blocks, &paddings}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    delete result;
}

TEST_F(DeclarableOpsTests6, Test_BtS_1) {
    auto x = NDArrayFactory::create<double>('f', {256, 8, 8, 2});
    auto blocks = NDArrayFactory::create<double>('c',{2}, {8, 8});
    auto crops = NDArrayFactory::create<double>('c', {2, 2});

    nd4j::ops::batch_to_space op;
    auto result = op.execute({&x, &blocks, &crops}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    delete result;
}

TEST_F(DeclarableOpsTests6, Test_Order_1) {
    auto x = NDArrayFactory::create<double>('f', {2, 3});
    auto exp = NDArrayFactory::create<double>('c', {2, 3});

@@ -1532,8 +1480,8 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    z->printIndexedBuffer("LogDet Output1 ");
    exp.printIndexedBuffer("LogDet Expected1 ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
|
||||||
@ -1554,9 +1502,32 @@ TEST_F(DeclarableOpsTests6, LogDet_2) {
|
|||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
// z->printIndexedBuffer("Output ");
|
z->printIndexedBuffer("LogDet Output2 ");
|
||||||
// z->printShapeInfo("Shape");
|
// z->printShapeInfo("Shape");
|
||||||
//exp.printIndexedBuffer("Expected ");
|
exp.printIndexedBuffer("LogDet Expected2 ");
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests6, LogDet_3) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {3, 3}, {4,12,-16,12,37,-43,-16,-43,98});
|
||||||
|
auto exp = NDArrayFactory::create<double>( 3.5835189);
|
||||||
|
|
||||||
|
//x.printIndexedBuffer("Input");
|
||||||
|
nd4j::ops::logdet op;
|
||||||
|
auto result = op.execute({&x}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printIndexedBuffer("LogDet Output3 ");
|
||||||
|
// z->printShapeInfo("Shape");
|
||||||
|
exp.printIndexedBuffer("LogDet Expected3 ");
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
@@ -3644,38 +3644,6 @@ TEST_F(DeclarableOpsTests7, fill_test3) {
     delete result;
 }

-////////////////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests7, clipbynorm_test3) {
-
-    auto x = NDArrayFactory::create<double>('c', {3, 5});
-    auto unities = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., 1.});
-    auto scale = NDArrayFactory::create<double>('c', {3, 1}, {1.1, 1., 0.9});
-
-    x.linspace(100.);
-
-    auto xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true);
-    x /= xNorm1;
-    xNorm1 = x.reduceAlongDims(reduce::Norm2,{1}, true);
-
-    ASSERT_TRUE(unities.isSameShape(xNorm1));
-    ASSERT_TRUE(unities.equalsTo(xNorm1));
-
-    x *= scale;
-    xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true);
-
-    nd4j::ops::clipbynorm op;
-    auto result = op.execute({&x}, {1.0}, {1}, {}, false, nd4j::DataType::DOUBLE);
-    auto z = result->at(0);
-
-    auto zNorm1 = z->reduceAlongDims(reduce::Norm2, {1}, true);
-    auto exp = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., xNorm1.e<double>(2)});
-
-    ASSERT_TRUE(exp.isSameShape(&zNorm1));
-    ASSERT_TRUE(exp.equalsTo(&zNorm1));
-
-    delete result;
-}

 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests7, mirrorPad_test1) {

@@ -184,7 +184,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
     destroyRandom((Nd4jPointer) rng);
     delete[] buffer;
 }
-*/

 //////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {
@@ -221,6 +221,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {

     delete[] buffer;
 }
+*/

 TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) {
     auto x = NDArrayFactory::create<double>('f', {2, 2}, {1.0, 3.0, 2.0, 4.0});
@@ -1586,7 +1587,6 @@ TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) {

     const int bS = 2;
     const int nOut = 3;
-    const int axis = 0;
     const double clip = 0.7;

     auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]

@@ -84,7 +84,7 @@ TEST_F(FlatBuffersTest, BasicTest1) {
     delete gB;
 }

+/*
 TEST_F(FlatBuffersTest, FlatGraphTest1) {
     flatbuffers::FlatBufferBuilder builder(4096);

@@ -205,7 +205,7 @@ TEST_F(FlatBuffersTest, FlatGraphTest1) {
     delete var0;
     delete resultWrapper;
 }
+*/
 TEST_F(FlatBuffersTest, ExecutionTest1) {
     auto gA = new Node(OpType_TRANSFORM_SAME);

@@ -97,6 +97,7 @@ TEST_F(GraphStateTests, Basic_Tests_2) {
     deleteGraphState(state);
 }

+/*
 TEST_F(GraphStateTests, Stateful_Execution_1) {
     auto state = getGraphState(117L);

@@ -117,16 +118,13 @@ TEST_F(GraphStateTests, Stateful_Execution_2) {

     Nd4jLong scopes[] = {22, 33};
     auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);

     // it's no-op: just LogicScope
     ASSERT_EQ(Status::OK(), status);

     deleteGraphState(state);
 }

-/**
- * This test checks WHILE loop
- */
+// This test checks WHILE loop
 TEST_F(GraphStateTests, Stateful_Execution_3) {
     auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
     auto var1 = NDArrayFactory::create<float>(11.0f);

@@ -193,10 +191,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
     // now we check provided result array
     float sum = res0.reduceNumber(reduce::Sum).e<float>(0);

-    /*
-     * Expected result is {1, 2, 3, 4} + {2} elementwise + {2} elementwise, which gives { 5, 6, 7, 8}, and sum should be 26
-     *
-     */
+    // Expected result is {1, 2, 3, 4} + {2} elementwise + {2} elementwise, which gives { 5, 6, 7, 8}, and sum should be 26
     ASSERT_NEAR(26.0f, sum, 1e-5);

     // nd4j_printf("0 ------------------\n","");
@@ -206,9 +201,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
     // nd4j_printf("1 ------------------\n","");
 }

-/**
- * This test checks CONDITIONAL execution for FALSE
- */
+// This test checks CONDITIONAL execution for FALSE
 TEST_F(GraphStateTests, Stateful_Execution_4) {
     auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
     auto var1 = NDArrayFactory::create<float>(5.0f);

@@ -282,9 +275,7 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
 }


-/**
- * This test checks CONDITIONAL execution for TRUE
- */
+// This test checks CONDITIONAL execution for TRUE
 TEST_F(GraphStateTests, Stateful_Execution_5) {
     auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
     auto var1 = NDArrayFactory::create<float>(5.0f);

@@ -355,3 +346,4 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {

     deleteGraphState(state);
 }
+*/
@@ -104,8 +104,8 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
     sub1.assign(1.0f);
     sub2.assign(2.0f);

-    Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer()};
-    Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo()};
+    Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer(), x.getSpecialBuffer(), sizes.getSpecialBuffer()};
+    Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo(), x.getSpecialShapeInfo(), sizes.getSpecialShapeInfo()};

     nd4j::ops::split_v op;

@@ -130,13 +130,11 @@ TEST_F(JavaInteropTests, Test_Squeeze_1) {

     nd4j::ops::squeeze op;

-    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer()};
-    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo()};
+    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
-    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
+    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

     auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
     ASSERT_EQ(Status::OK(), status);

@@ -149,16 +147,19 @@ TEST_F(JavaInteropTests, Test_RDiv_1) {
     auto z = NDArrayFactory::create<double>('c', {3});
     auto e = NDArrayFactory::create<double>('c', {3}, {2, 3, 4});

+    NDArray::prepareSpecialUse({&z}, {&x, &y});

     nd4j::ops::reversedivide op;

-    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer()};
-    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo()};
+    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
-    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
+    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

     auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

+    NDArray::registerSpecialUse({&z}, {&x, &y});
     ASSERT_EQ(Status::OK(), status);

     ASSERT_EQ(e, z);
@ -183,13 +184,14 @@ TEST_F(JavaInteropTests, TestSconv2d_1) {
|
|||||||
|
|
||||||
nd4j::ops::sconv2d op;
|
nd4j::ops::sconv2d op;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias});
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer(), (Nd4jPointer) weightsP.getBuffer(), (Nd4jPointer) bias.getBuffer()};
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer(), (Nd4jPointer) weightsP.getBuffer(), (Nd4jPointer) bias.getBuffer(), (Nd4jPointer) input.getSpecialBuffer(), (Nd4jPointer) weightsD.getSpecialBuffer(), (Nd4jPointer) weightsP.getSpecialBuffer(), (Nd4jPointer) bias.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo(), (Nd4jPointer) weightsP.getShapeInfo(), (Nd4jPointer) bias.getShapeInfo()};
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo(), (Nd4jPointer) weightsP.getShapeInfo(), (Nd4jPointer) bias.getShapeInfo(), (Nd4jPointer) input.getSpecialShapeInfo(), (Nd4jPointer) weightsD.getSpecialShapeInfo(), (Nd4jPointer) weightsP.getSpecialShapeInfo(), (Nd4jPointer) bias.getSpecialShapeInfo()};
|
||||||
|
|
||||||
|
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), (Nd4jPointer) output.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), (Nd4jPointer) output.getSpecialShapeInfo()};
|
||||||
|
|
||||||
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};
|
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};
|
||||||
|
|
||||||
@ -197,6 +199,7 @@ TEST_F(JavaInteropTests, TestSconv2d_1) {
|
|||||||
nullptr, 0, exp, 9, nullptr, 0, false);
|
nullptr, 0, exp, 9, nullptr, 0, false);
|
||||||
|
|
||||||
//output.printBuffer("output");
|
//output.printBuffer("output");
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias});
|
||||||
|
|
||||||
ASSERT_NEAR(1423, output.e<float>(0), 1e-5);
|
ASSERT_NEAR(1423, output.e<float>(0), 1e-5);
|
||||||
//nd4j_printf("Iter %i passed...\n", e);
|
//nd4j_printf("Iter %i passed...\n", e);
|
||||||
@ -216,19 +219,20 @@ TEST_F(JavaInteropTests, TestSconv2d_2) {
|
|||||||
|
|
||||||
nd4j::ops::sconv2d op;
|
nd4j::ops::sconv2d op;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input, &weightsD});
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer()};
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) weightsD.getBuffer(), input.getSpecialBuffer(), weightsD.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo()};
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weightsD.getShapeInfo(), input.getSpecialShapeInfo(), weightsD.getSpecialShapeInfo()};
|
||||||
|
|
||||||
|
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};
|
||||||
|
|
||||||
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};
|
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};
|
||||||
|
|
||||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
||||||
|
|
||||||
//output.printBuffer("output");
|
NDArray::registerSpecialUse({&output}, {&input, &weightsD});
|
||||||
|
|
||||||
ASSERT_NEAR(1, output.e<float>(0), 1e-5);
|
ASSERT_NEAR(1, output.e<float>(0), 1e-5);
|
||||||
}
|
}
|
||||||
@ -239,18 +243,21 @@ TEST_F(JavaInteropTests, TestMaxPooling2d_1) {
|
|||||||
auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
|
auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input});
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer()};
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};
|
||||||
|
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};
|
||||||
|
|
||||||
std::vector<Nd4jLong> iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1});
|
std::vector<Nd4jLong> iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1});
|
||||||
|
|
||||||
nd4j::ops::maxpool2d op;
|
nd4j::ops::maxpool2d op;
|
||||||
|
|
||||||
Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
|
Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -267,12 +274,13 @@ TEST_F(JavaInteropTests, TestCol2Im_1) {
|
|||||||
auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
|
auto output = NDArrayFactory::create<float>('c', {1, 2, 4, 5});
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer()};
|
NDArray::prepareSpecialUse({&output}, {&input});
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};
|
|
||||||
|
|
||||||
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
|
||||||
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};
|
||||||
|
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};
|
||||||
|
|
||||||
nd4j::ops::col2im op;
|
nd4j::ops::col2im op;
|
||||||
|
|
||||||
@ -282,6 +290,8 @@ TEST_F(JavaInteropTests, TestCol2Im_1) {
|
|||||||
|
|
||||||
execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input});
|
||||||
|
|
||||||
ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
|
ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -300,18 +310,23 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
|
|||||||
auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
|
auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input});
|
||||||
|
|
||||||
nd4j::ops::pnormpool2d op;
|
nd4j::ops::pnormpool2d op;
|
||||||
|
|
||||||
Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};
|
Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer()};
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};
|
||||||
|
|
||||||
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
|
||||||
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};
|
||||||
|
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
|
||||||
|
|
||||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input});
|
||||||
|
|
||||||
ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
|
ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -321,16 +336,20 @@ TEST_F(JavaInteropTests, TestInplace_1) {
|
|||||||
//auto exp('c', {10, 10});
|
//auto exp('c', {10, 10});
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({}, {&input});
|
||||||
|
|
||||||
nd4j::ops::clipbyvalue op;
|
nd4j::ops::clipbyvalue op;
|
||||||
|
|
||||||
double extras[] = {-1.0f, 1.0f};
|
double extras[] = {-1.0f, 1.0f};
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer()};
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), input.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), input.getSpecialShapeInfo()};
|
||||||
|
|
||||||
|
|
||||||
Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
|
Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({}, {&input});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result);
|
ASSERT_EQ(ND4J_STATUS_OK, result);
|
||||||
|
|
||||||
ASSERT_NEAR(1.0, input.meanNumber().e<float>(0), 1e-5);
|
ASSERT_NEAR(1.0, input.meanNumber().e<float>(0), 1e-5);
|
||||||
@ -381,6 +400,7 @@ TEST_F(JavaInteropTests, Test_Synonyms_3) {
|
|||||||
ASSERT_EQ(nameRef, name);
|
ASSERT_EQ(nameRef, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
||||||
int inOutH = 35;
|
int inOutH = 35;
|
||||||
int inOutW = 35;
|
int inOutW = 35;
|
||||||
@ -391,19 +411,23 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
|||||||
x.linspace(1.0);
|
x.linspace(1.0);
|
||||||
z.linspace(1.0);
|
z.linspace(1.0);
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x});
|
||||||
|
|
||||||
nd4j::ops::avgpool2d op;
|
nd4j::ops::avgpool2d op;
|
||||||
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
|
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
|
||||||
|
|
||||||
Nd4jLong exp[] = {3,3, 1,1, 0,0, 1,1, 1, 0, 1};
|
Nd4jLong exp[] = {3,3, 1,1, 0,0, 1,1, 1, 0, 1};
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer()};
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo()};
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
|
||||||
|
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};
|
||||||
|
|
||||||
auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result);
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
|
||||||
int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
|
int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH;
|
||||||
@ -469,7 +493,7 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
|||||||
ASSERT_EQ(m, z);
|
ASSERT_EQ(m, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
TEST_F(JavaInteropTests, Test_GraphReuse_1) {
|
TEST_F(JavaInteropTests, Test_GraphReuse_1) {
|
||||||
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
||||||
|
|
||||||
@ -577,18 +601,19 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
|
|||||||
|
|
||||||
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});
|
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&o}, {&x, &y});
|
||||||
|
|
||||||
nd4j::ops::greater op;
|
nd4j::ops::greater op;
|
||||||
|
|
||||||
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
|
||||||
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer(), o.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo(), o.getSpecialShapeInfo()};
|
||||||
|
|
||||||
|
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
|
||||||
|
|
||||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
o.printIndexedBuffer("Greater JIT");
|
|
||||||
|
NDArray::registerSpecialUse({&o}, {&x, &y});
|
||||||
ASSERT_TRUE(exp.equalsTo(&o));
|
ASSERT_TRUE(exp.equalsTo(&o));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -602,35 +627,41 @@ TEST_F(JavaInteropTests, Test_Greater_2) {
|
|||||||
|
|
||||||
nd4j::ops::greater op;
|
nd4j::ops::greater op;
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&o}, {&x, &y});
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer()};
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo()};
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};
|
||||||
|
|
||||||
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer(), o.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo(), o.getSpecialShapeInfo()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
|
||||||
|
|
||||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&o}, {&x, &y});
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(&o));
|
ASSERT_TRUE(exp.equalsTo(&o));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
|
TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
|
||||||
|
|
||||||
nd4j::ops::is_non_decreasing op;
|
nd4j::ops::is_non_decreasing op;
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5});
|
auto x = NDArrayFactory::create<float>('c', {5}, {1, 2, 3, 4, 5});
|
||||||
auto o = NDArrayFactory::create<bool>(false);
|
auto o = NDArrayFactory::create<bool>(false);
|
||||||
auto exp = NDArrayFactory::create<bool>(1);
|
auto exp = NDArrayFactory::create<bool>(1);
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer()};
|
NDArray::prepareSpecialUse({&o}, {&x});
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo()};
|
|
||||||
|
|
||||||
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
|
||||||
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};
|
||||||
|
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer(), o.getSpecialBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo(), o.getSpecialShapeInfo()};
|
||||||
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&o}, {&x});
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(&o));
|
ASSERT_TRUE(exp.equalsTo(&o));
|
||||||
@@ -644,15 +675,18 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) {

     nd4j::ops::test_output_reshape op;

-    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer()};
-    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo()};
+    NDArray::prepareSpecialUse({&z}, {&x});
+
+    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
-    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
+    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

     auto hash = op.getOpHash();
     auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

+    NDArray::registerSpecialUse({&z}, {&x});
+
     ASSERT_EQ(Status::OK(), status);

     ASSERT_TRUE(exp.isSameShape(z));
@@ -669,14 +703,18 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {

     nd4j::ops::add op;

-    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer()};
-    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo()};
+    NDArray::prepareSpecialUse({&z}, {&x, &y});
+
+    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), (Nd4jPointer) y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), (Nd4jPointer) y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
-    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
+    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

     auto hash = op.getOpHash();
     auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

+    NDArray::registerSpecialUse({&z}, {&x, &y});
+
     ASSERT_EQ(Status::OK(), status);

     ASSERT_TRUE(e.isSameShape(z));
@@ -692,16 +730,20 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {

     nd4j::ops::gather op;

-    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) indices.getBuffer()};
-    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) indices.getShapeInfo()};
+    NDArray::prepareSpecialUse({&output}, {&input, &indices});
+
+    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.getBuffer(), (Nd4jPointer) indices.getBuffer(), input.getSpecialBuffer(), indices.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) indices.getShapeInfo(), input.getSpecialShapeInfo(), indices.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
-    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
+    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer(), output.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo(), output.getSpecialShapeInfo()};

     Nd4jLong iArgs[] = {1};

     auto hash = op.getOpHash();
     auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);

+    NDArray::registerSpecialUse({&output}, {&input, &indices});
+
     ASSERT_EQ(Status::OK(), status);

     ASSERT_TRUE(e.isSameShape(output));
@@ -715,12 +757,31 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
     auto z = NDArrayFactory::create<float>('c', {5});

     auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});
+    dims.syncToHost();

-    execReduce3Tad(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
-                   y.buffer(), y.shapeInfo(), nullptr, nullptr,
-                   z.buffer(), z.shapeInfo(), nullptr, nullptr,
-                   dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr);
+    nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
+
+    Nd4jPointer* extraPointers = nullptr;
+#ifdef __CUDABLAS__
+    extraPointers = new Nd4jPointer[6] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer()};
+#endif
+
+    auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), {0,1});
+    auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {0,1});
+
+    NDArray::prepareSpecialUse({&z}, {&x, &y, &dims});
+
+    execReduce3Tad(extraPointers, 2, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
+                   nullptr,
+                   y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
+                   z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
+                   dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
+
+    NDArray::registerSpecialUse({&z}, {&x, &y, &dims});
+
+    delete []extraPointers;
 }

 /*
 TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
     Environment::getInstance()->setDebug(true);
@@ -745,17 +806,21 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {

     nd4j::ops::avgpool2d op;

-    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer())};
-    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo())};
+    NDArray::prepareSpecialUse({&z}, {&input});
+
+    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer())};
-    Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
+    Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};

     Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};

     auto hash = op.getOpHash();
     auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);

+    NDArray::registerSpecialUse({&z}, {&input});
+
     ASSERT_EQ(Status::OK(), status);

     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));
@@ -767,11 +832,13 @@ TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {

     input.linspace(1.0);

-    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer())};
-    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo())};
+    NDArray::prepareSpecialUse({&z}, {&input});
+
+    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer())};
-    Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
+    Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};

     Nd4jLong iArgs[] = {2,2, 1,1, 1,1, 2,2,1, 0,0};

@@ -779,6 +846,8 @@ TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {

     auto hash = op.getOpHash();
     auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);

+    NDArray::registerSpecialUse({&z}, {&input});
+
     ASSERT_EQ(Status::OK(), status);
 }

@@ -791,11 +860,13 @@ TEST_F(JavaInteropTests, Test_Unstack_1) {
     auto z3 = NDArrayFactory::create<double>('c',{5});
     auto z4 = NDArrayFactory::create<double>('c',{5});

-    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(x.buffer())};
-    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(x.shapeInfo())};
+    NDArray::prepareSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x});
+
+    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(x.buffer()), x.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(x.shapeInfo()), x.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {z0.buffer(), z1.buffer(), z2.buffer(), z3.buffer(), z4.buffer()};
-    Nd4jPointer ptrsOutShapes[] = {z0.shapeInfo(), z1.shapeInfo(), z2.shapeInfo(), z3.shapeInfo(), z4.shapeInfo()};
+    Nd4jPointer ptrsOutBuffers[] = {z0.buffer(), z1.buffer(), z2.buffer(), z3.buffer(), z4.buffer(), z0.getSpecialBuffer(), z1.getSpecialBuffer(), z2.getSpecialBuffer(), z3.getSpecialBuffer(), z4.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {z0.shapeInfo(), z1.shapeInfo(), z2.shapeInfo(), z3.shapeInfo(), z4.shapeInfo(), z0.getSpecialShapeInfo(), z1.getSpecialShapeInfo(), z2.getSpecialShapeInfo(), z3.getSpecialShapeInfo(), z4.getSpecialShapeInfo()};

     Nd4jLong iArgs[] = {0};

@@ -803,6 +874,8 @@ TEST_F(JavaInteropTests, Test_Unstack_1) {

     auto hash = op.getOpHash();
     auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);

+    NDArray::registerSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x});
+
     ASSERT_EQ(Status::OK(), status);
 }

@@ -814,17 +887,20 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) {

     nd4j::ops::avgpool2d op;

-    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer())};
-    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo())};
+    NDArray::prepareSpecialUse({&z}, {&input});
+
+    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(input.buffer()), input.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(input.shapeInfo()), input.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer())};
-    Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
+    Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};

     Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};

     auto hash = op.getOpHash();
     auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);

+    NDArray::registerSpecialUse({&z}, {&input});
+
     ASSERT_EQ(Status::OK(), status);

     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));
@@ -839,12 +915,16 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
     auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0});
     auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8});

+    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY});
+
     execPairwiseTransform(nullptr, pairwise::Add,
-                          arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
-                          arrayY.buffer(), arrayY.shapeInfo(), nullptr, nullptr,
-                          arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
+                          arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
+                          arrayY.buffer(), arrayY.shapeInfo(), arrayY.getSpecialBuffer(), arrayY.getSpecialShapeInfo(),
+                          arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
                           nullptr);

+    NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY});
+
     ASSERT_EQ(arrayE, arrayZ);
 }

@@ -853,16 +933,20 @@ TEST_F(JavaInteropTests, Test_Add_1) {
     auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
     auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});

+    NDArray::prepareSpecialUse({&x}, {&x, &y});
+
     nd4j::ops::add op;

-    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer()};
-    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), y.getShapeInfo()};
+    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer(), x.getSpecialBuffer(), y.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), y.getShapeInfo(), x.getSpecialShapeInfo(), y.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer()};
-    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo()};
+    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer(), x.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo(), x.getSpecialShapeInfo()};

     execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

+    NDArray::registerSpecialUse({&x}, {&x, &y});
+
     ASSERT_EQ(e, x);
 }

@@ -876,14 +960,18 @@ TEST_F(JavaInteropTests, zeta_test10) {

     nd4j::ops::zeta op;

-    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer()};
-    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), q.getShapeInfo()};
+    NDArray::prepareSpecialUse({&z}, {&x, &q});
+
+    Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer(), x.getSpecialBuffer(), q.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.getShapeInfo(), q.getShapeInfo(), x.getSpecialShapeInfo(), q.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
-    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
+    Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer(), z.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo(), z.getSpecialShapeInfo()};

     execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

+    NDArray::registerSpecialUse({&z}, {&x, &q});
+
     ASSERT_EQ(e, z);
 }

@@ -892,12 +980,23 @@ TEST_F(JavaInteropTests, Test_Is_Max_1) {
     auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
     auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});

-    execTransformAny(nullptr, transform::IsMax,
-                     arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
-                     arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
+    nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
+
+    Nd4jPointer* extraPointers = nullptr;
+#ifdef __CUDABLAS__
+    extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
+#endif
+
+    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
+
+    execTransformAny(extraPointers, transform::IsMax,
+                     arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
+                     arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
                      nullptr);
-    //arrayZ.printIndexedBuffer("JAVA ISMAX1");

+    NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
+
     ASSERT_EQ(arrayE, arrayZ);

+    delete []extraPointers;
 }

 TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
@@ -905,12 +1004,22 @@ TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
     auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
     auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});

-    execTransformAny(nullptr, transform::IsMax,
-                     arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
-                     arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
+    nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
+
+    Nd4jPointer* extraPointers = nullptr;
+#ifdef __CUDABLAS__
+    extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
+#endif
+
+    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
+
+    execTransformAny(extraPointers, transform::IsMax,
+                     arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
+                     arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
                      nullptr);
     //arrayZ.printIndexedBuffer("JAVA ISMAX1");

+    NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
+
     ASSERT_EQ(arrayE, arrayZ);

+    delete []extraPointers;
 }

 TEST_F(JavaInteropTests, Test_Is_Max_2) {
@@ -921,10 +1030,12 @@ TEST_F(JavaInteropTests, Test_Is_Max_2) {
     Nd4jLong *ex[] = {tad, off};
     float ea[] = {2, 1, 2};

+    NDArray::prepareSpecialUse({&arrayZ}, {&arrayX});
+
     execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
-                      arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
-                      arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
+                      arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(),
+                      arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(),
                       ea);

+    NDArray::registerSpecialUse({&arrayZ}, {&arrayX});
 }

 TEST_F(JavaInteropTests, Test_IAMax_1) {
@@ -939,14 +1050,13 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
     auto arrayX = NDArrayFactory::create<double>('c', {10, 10});
     auto arrayY = NDArrayFactory::create<double>('c', {10, 10});

-    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(arrayX.buffer()), reinterpret_cast<Nd4jPointer>(arrayY.buffer())};
-    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(arrayX.shapeInfo()), reinterpret_cast<Nd4jPointer>(arrayY.shapeInfo())};
+    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(arrayX.buffer()), reinterpret_cast<Nd4jPointer>(arrayY.buffer()), arrayX.getSpecialBuffer(), arrayY.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(arrayX.shapeInfo()), reinterpret_cast<Nd4jPointer>(arrayY.shapeInfo()), arrayX.getSpecialShapeInfo(), arrayY.getSpecialShapeInfo()};

+    NDArray::prepareSpecialUse({}, {&arrayX, &arrayY});
     nd4j::ops::greater_equal op;
     auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
+    NDArray::registerSpecialUse({}, {&arrayX, &arrayY});
     delete shapeList;
 }

@@ -955,17 +1065,19 @@ TEST_F(JavaInteropTests, Test_L2_Loss_3) {
     auto e = NDArrayFactory::create<double>(0.303254);
     auto z = NDArrayFactory::create<double>(0.0);

-    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(x.buffer())};
-    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(x.shapeInfo())};
+    NDArray::prepareSpecialUse({&z}, {&x});
+
+    Nd4jPointer ptrsInBuffer[] = {reinterpret_cast<Nd4jPointer>(x.buffer()), x.getSpecialBuffer()};
+    Nd4jPointer ptrsInShapes[] = {reinterpret_cast<Nd4jPointer>(x.shapeInfo()), x.getSpecialShapeInfo()};

-    Nd4jPointer ptrsOutBuffer[] = {reinterpret_cast<Nd4jPointer>(z.buffer())};
-    Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
+    Nd4jPointer ptrsOutBuffer[] = {reinterpret_cast<Nd4jPointer>(z.buffer()), z.getSpecialBuffer()};
+    Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo()), z.getSpecialShapeInfo()};

     nd4j::ops::l2_loss op;
     auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
     ASSERT_EQ(Status::OK(), status);

-    z.printIndexedBuffer("z");
+    NDArray::registerSpecialUse({&z}, {&x});

     ASSERT_EQ(e, z);
 }
@@ -978,15 +1090,19 @@ TEST_F(JavaInteropTests, Test_Fastpath_3) {
     auto exp = NDArrayFactory::create<float>('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f});
     Context ctx(1);

-    ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), nullptr, nullptr);
-    ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), nullptr, nullptr);
-    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), nullptr, nullptr);
+    NDArray::prepareSpecialUse({&z}, {&array0, &array1});
+
+    ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.getSpecialBuffer(), array0.getSpecialShapeInfo());
+    ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.getSpecialBuffer(), array1.getSpecialShapeInfo());
+    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo());

     ASSERT_EQ(2, ctx.width());

     nd4j::ops::add op;
     execCustomOp2(nullptr, op.getOpHash(), &ctx);

+    NDArray::registerSpecialUse({&z}, {&array0, &array1});
+
     ASSERT_EQ(exp, z);
 }

@@ -996,6 +1112,9 @@ TEST_F(JavaInteropTests, Test_Fastpath_4) {
     auto z = NDArrayFactory::create<double>('c', {3, 5});
     Nd4jLong iArgs[] = {3, 5, 2};

+    NDArray::prepareSpecialUse({&z}, {});
+
     Context ctx(1);

     ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
@@ -1004,6 +1123,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_4) {
     nd4j::ops::tri op;
     execCustomOp2(nullptr, op.getOpHash(), &ctx);

+    NDArray::registerSpecialUse({&z}, {});
+
     ASSERT_EQ(exp, z);
 }

@@ -1014,6 +1135,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_5) {
     a.linspace(1.0);
     b.linspace(1.0);

+    NDArray::prepareSpecialUse({&c}, {&a, &b});
+
     Context ctx(1);

     ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
@@ -1023,6 +1146,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_5) {
     nd4j::ops::matmul op;
     auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);

+    NDArray::registerSpecialUse({&c}, {&a, &b});
+
     ASSERT_EQ(Status::OK(), status);
 }

@@ -1037,6 +1162,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_6) {
     b.linspace(1.0);
     gI.linspace(1.0);

+    NDArray::prepareSpecialUse({&gA, &gB}, {&a, &b, &gI});
+
     Context ctx(1);
     Nd4jLong iArgs[] = {0L, 0L, 0L};

@@ -1052,6 +1179,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_6) {
     nd4j::ops::matmul_bp op;
     auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);

+    NDArray::registerSpecialUse({&gA, &gB}, {&a, &b, &gI});
+
     ASSERT_EQ(Status::OK(), status);
 }

@@ -1061,6 +1190,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
     auto z = NDArrayFactory::create<float>('c', {3});
     auto e = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});

+    NDArray::prepareSpecialUse({&z}, {&a, &b});
+
     Context ctx(1);
     Nd4jLong iArgs[] = {0L, 0L, 0L};

@@ -1074,6 +1205,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
     ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());

     auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);

+    NDArray::registerSpecialUse({&z}, {&a, &b});
+
     ASSERT_EQ(Status::OK(), status);

     ASSERT_EQ(e, z);

@ -172,7 +172,7 @@ TEST_F(LegacyOpsTests, ReduceTests_1) {
|
|||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result->size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
z->printBuffer("ReduceTest1");
|
// z->printBuffer("ReduceTest1");
|
||||||
ASSERT_TRUE(z->isScalar());
|
ASSERT_TRUE(z->isScalar());
|
||||||
ASSERT_NEAR(x.sumNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
ASSERT_NEAR(x.sumNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
||||||
|
|
||||||
@ -232,10 +232,10 @@ TEST_F(LegacyOpsTests, ReduceTests_4) {
|
|||||||
auto result = op.execute({&x, &indices}, {}, {}, {true});
|
auto result = op.execute({&x, &indices}, {}, {}, {true});
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
auto exp = x.reduceAlongDims(reduce::Sum, {1}, true);
|
auto exp = x.reduceAlongDims(reduce::Sum, {1}, true);
|
||||||
indices.printShapeInfo("Indices shape");
|
// indices.printShapeInfo("Indices shape");
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
z->printIndexedBuffer("Output reduce 4");
|
// z->printIndexedBuffer("Output reduce 4");
|
||||||
exp.printIndexedBuffer("Expected reduce 4");
|
// exp.printIndexedBuffer("Expected reduce 4");
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
@ -253,7 +253,7 @@ TEST_F(LegacyOpsTests, ReduceTests_5) {
|
|||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result->size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
z->printBuffer("ReduceTest1");
|
// z->printBuffer("ReduceTest1");
|
||||||
ASSERT_TRUE(z->isScalar());
|
ASSERT_TRUE(z->isScalar());
|
||||||
ASSERT_NEAR(x.meanNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
ASSERT_NEAR(x.meanNumber().e<float>(0), z->e<float>(0), 1e-5f);
|
||||||
|
|
||||||
@ -315,9 +315,9 @@ TEST_F(LegacyOpsTests, ReduceTests_8) {
|
|||||||
auto exp = x.reduceAlongDims(reduce::Mean, {1}, true);
|
auto exp = x.reduceAlongDims(reduce::Mean, {1}, true);
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
z->printIndexedBuffer("Reduce8 output");
|
// z->printIndexedBuffer("Reduce8 output");
|
||||||
z->printShapeInfo("Reduce8 shape");
|
// z->printShapeInfo("Reduce8 shape");
|
||||||
exp.printShapeInfo("Reduce8 expected shape");
|
// exp.printShapeInfo("Reduce8 expected shape");
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
|
||||||
@ -356,7 +356,7 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) {
|
|||||||
ASSERT_EQ(1, result->size());
|
ASSERT_EQ(1, result->size());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
z->printIndexedBuffer("Hello indexreduce2");
|
// z->printIndexedBuffer("Hello indexreduce2");
|
||||||
ASSERT_TRUE(exp.equalsTo(z));
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
//ASSERT_EQ(4, z->e<int>(0));
|
//ASSERT_EQ(4, z->e<int>(0));
|
||||||
//ASSERT_EQ(4, z->e<int>(1));
|
//ASSERT_EQ(4, z->e<int>(1));
|
||||||
@ -423,8 +423,8 @@ TEST_F(LegacyOpsTests, BroadcastingTests_1) {
|
|||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
auto list = x.allTensorsAlongDimension({1});
|
auto list = x.allTensorsAlongDimension({1});
|
||||||
x.printIndexedBuffer("Output broadcast");
|
// x.printIndexedBuffer("Output broadcast");
|
||||||
list->at(0)->printIndexedBuffer("Column 0:");
|
// list->at(0)->printIndexedBuffer("Column 0:");
|
||||||
for (int e = 0; e < list->size(); e++)
|
for (int e = 0; e < list->size(); e++)
|
||||||
ASSERT_TRUE(row.equalsTo(list->at(e)));
|
ASSERT_TRUE(row.equalsTo(list->at(e)));
|
||||||
|
|
||||||
@ -439,14 +439,15 @@ TEST_F(LegacyOpsTests, BroadcastingTests_2) {
|
|||||||
e.assign(4.0);
|
e.assign(4.0);
|
||||||
|
|
||||||
int axis = 1;
|
int axis = 1;
|
||||||
shape::TAD tad;
|
|
||||||
tad.init(y.shapeInfo(), &axis, 1);
|
|
||||||
tad.createTadOnlyShapeInfo();
|
|
||||||
tad.createOffsets();
|
|
||||||
|
|
||||||
shape::printShapeInfoLinear("tad shape", tad.tadOnlyShapeInfo);
|
// shape::printShapeInfoLinear("tad shape", tad.tadOnlyShapeInfo);
|
||||||
|
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), {axis});
|
||||||
|
|
||||||
NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &axis, 1, tad.tadOnlyShapeInfo, tad.tadOffsets, tad.tadOnlyShapeInfo, tad.tadOffsets);
|
NDArray::prepareSpecialUse({&y}, {&x});
|
||||||
|
|
||||||
|
NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &axis, 1, packY.platformShapeInfo(), packY.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&y}, {&x});
|
||||||
|
|
||||||
ASSERT_EQ(e, y);
|
ASSERT_EQ(e, y);
|
||||||
}
|
}
|
||||||
@ -500,9 +501,30 @@ TEST_F(LegacyOpsTests, Reduce3_2) {
|
|||||||
auto z = NDArrayFactory::create<float>('c', {5});
|
auto z = NDArrayFactory::create<float>('c', {5});
|
||||||
|
|
||||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
dim.syncToHost();
|
||||||
|
|
||||||
execReduce3Tad(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
|
||||||
nullptr, nullptr, nullptr, nullptr);
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), {1});
|
||||||
|
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
|
||||||
|
execReduce3Tad(extraPointers, reduce3::CosineSimilarity,
|
||||||
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
|
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
|
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||||
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
|
||||||
|
delete []extraPointers;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LegacyOpsTests, Reduce3_3) {
|
TEST_F(LegacyOpsTests, Reduce3_3) {
|
||||||
@ -515,19 +537,31 @@ TEST_F(LegacyOpsTests, Reduce3_3) {
|
|||||||
auto z = NDArrayFactory::create<double>('c', {3});
|
auto z = NDArrayFactory::create<double>('c', {3});
|
||||||
|
|
||||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
dim.syncToHost();
|
||||||
|
|
||||||
|
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
|
||||||
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), {1});
|
||||||
|
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
|
||||||
|
|
||||||
execReduce3Tad(nullptr, reduce3::CosineDistance,
|
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||||
nullptr, nullptr, nullptr, nullptr);
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
// z.printIndexedBuffer("z");
|
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
delete []extraPointers;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LegacyOpsTests, Reduce3_4) {
|
TEST_F(LegacyOpsTests, Reduce3_4) {
|
||||||
@ -540,19 +574,33 @@ TEST_F(LegacyOpsTests, Reduce3_4) {
|
|||||||
auto z = NDArrayFactory::create<double>('c', {1, 3});
|
auto z = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
|
|
||||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
dim.syncToHost();
|
||||||
|
|
||||||
|
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
|
||||||
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), {1});
|
||||||
|
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
|
||||||
|
|
||||||
execReduce3Tad(nullptr, reduce3::CosineDistance,
|
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||||
nullptr, nullptr, nullptr, nullptr);
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
// z.printIndexedBuffer("z");
|
// z.printIndexedBuffer("z");
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
|
delete []extraPointers;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LegacyOpsTests, Reduce3_5) {
|
TEST_F(LegacyOpsTests, Reduce3_5) {
|
||||||
@ -565,19 +613,32 @@ TEST_F(LegacyOpsTests, Reduce3_5) {
|
|||||||
auto z = NDArrayFactory::create<double>('c', {1, 3});
|
auto z = NDArrayFactory::create<double>('c', {1, 3});
|
||||||
|
|
||||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
dim.syncToHost();
|
||||||
|
|
||||||
|
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
|
||||||
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), {1});
|
||||||
|
auto packY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.getShapeInfo(), {1});
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y, &dim});
|
||||||
|
|
||||||
|
|
||||||
execReduce3Tad(nullptr, reduce3::CosineDistance,
|
execReduce3Tad(extraPointers, reduce3::CosineDistance,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||||
nullptr, nullptr, nullptr, nullptr);
|
packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets());
|
||||||
|
|
||||||
z.printIndexedBuffer("z");
|
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y, &dim});
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
|
delete []extraPointers;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
|
TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
|
||||||
@ -589,12 +650,25 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
|
|||||||
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1);
|
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1);
|
||||||
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1);
|
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1);
|
||||||
|
|
||||||
execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext();
|
||||||
|
|
||||||
|
Nd4jPointer* extraPointers = nullptr;
|
||||||
|
#ifdef __CUDABLAS__
|
||||||
|
extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({&z}, {&x, &y});
|
||||||
|
|
||||||
|
execReduce3All(extraPointers, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||||
tadPackX.platformShapeInfo(), tadPackX.platformOffsets(),
|
tadPackX.platformShapeInfo(), tadPackX.platformOffsets(),
|
||||||
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
|
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&z}, {&x, &y});
|
||||||
|
|
||||||
|
delete []extraPointers;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -634,7 +708,7 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
|
|||||||
|
|
||||||
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1);
|
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1);
|
||||||
|
|
||||||
y.tickWriteDevice();
|
z.tickWriteDevice();
|
||||||
|
|
||||||
NativeOpExecutioner::execInverseBroadcastBool(LaunchContext::defaultContext(), broadcast::BoolOps::EqualTo,
|
NativeOpExecutioner::execInverseBroadcastBool(LaunchContext::defaultContext(), broadcast::BoolOps::EqualTo,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
@ -657,7 +731,11 @@ TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) {
|
|||||||
|
|
||||||
int dim = 1;
|
int dim = 1;
|
||||||
|
|
||||||
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr);
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Sum,
|
||||||
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
|
&dim, 1, x.getPlatformShapeInfo(), nullptr);
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
@ -670,7 +748,7 @@ TEST_F(LegacyOpsTests, test_legacy_reduce_empty_2) {
|
|||||||
|
|
||||||
int dim = 1;
|
int dim = 1;
|
||||||
|
|
||||||
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr);
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.getPlatformShapeInfo(), nullptr);
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
@ -683,7 +761,7 @@ TEST_F(LegacyOpsTests, test_legacy_reduce_empty_3) {
|
|||||||
|
|
||||||
int dim = 1;
|
int dim = 1;
|
||||||
|
|
||||||
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr);
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.getPlatformShapeInfo(), nullptr);
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
|
@ -33,6 +33,10 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(MmapTests, Test_Basic_Mmap_1) {
|
TEST_F(MmapTests, Test_Basic_Mmap_1) {
|
||||||
|
// FIXME: we must adopt this for CUDA as well
|
||||||
|
if (!Environment::getInstance()->isCPU())
|
||||||
|
return;
|
||||||
|
|
||||||
// small allocation for the test
|
// small allocation for the test
|
||||||
Nd4jLong size = 100000L;
|
Nd4jLong size = 100000L;
|
||||||
|
|
||||||
|
@ -1901,14 +1901,14 @@ TEST_F(MultiDataTypeTests, Test_Cast_1) {
|
|||||||
|
|
||||||
asBool.assign(first);
|
asBool.assign(first);
|
||||||
|
|
||||||
asBool.printIndexedBuffer("asBool");
|
// asBool.printIndexedBuffer("asBool");
|
||||||
asBool.applyScalar(scalar::Not, false, &_not);
|
asBool.applyScalar(scalar::Not, false, &_not);
|
||||||
|
|
||||||
_not.printIndexedBuffer("_not");
|
// _not.printIndexedBuffer("_not");
|
||||||
|
|
||||||
asFloat.assign(_not);
|
asFloat.assign(_not);
|
||||||
|
|
||||||
asFloat.printIndexedBuffer("asFloat");
|
// asFloat.printIndexedBuffer("asFloat");
|
||||||
ASSERT_EQ(exp, asFloat);
|
ASSERT_EQ(exp, asFloat);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1922,14 +1922,14 @@ TEST_F(MultiDataTypeTests, Test_Cast_2) {
|
|||||||
|
|
||||||
asBool.assign(first);
|
asBool.assign(first);
|
||||||
|
|
||||||
asBool.printIndexedBuffer("asBool");
|
// asBool.printIndexedBuffer("asBool");
|
||||||
asBool.applyTransform(transform::Not, &_not);
|
asBool.applyTransform(transform::Not, &_not);
|
||||||
|
|
||||||
_not.printIndexedBuffer("_not");
|
// _not.printIndexedBuffer("_not");
|
||||||
|
|
||||||
asFloat.assign(_not);
|
asFloat.assign(_not);
|
||||||
|
|
||||||
asFloat.printIndexedBuffer("asFloat");
|
// asFloat.printIndexedBuffer("asFloat");
|
||||||
ASSERT_EQ(exp, asFloat);
|
ASSERT_EQ(exp, asFloat);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1945,7 +1945,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) {
|
|||||||
NDArray x3 = x1 / x2;
|
NDArray x3 = x1 / x2;
|
||||||
}
|
}
|
||||||
catch (std::exception& message) {
|
catch (std::exception& message) {
|
||||||
printf("%s\n", message.what());
|
// printf("%s\n", message.what());
|
||||||
ASSERT_TRUE(1);
|
ASSERT_TRUE(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1953,7 +1953,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) {
|
|||||||
x1 /= x2;
|
x1 /= x2;
|
||||||
}
|
}
|
||||||
catch (std::exception& message) {
|
catch (std::exception& message) {
|
||||||
printf("%s\n", message.what());
|
// printf("%s\n", message.what());
|
||||||
ASSERT_TRUE(1);
|
ASSERT_TRUE(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1961,7 +1961,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) {
|
|||||||
NDArray x3 = 150. / x2;
|
NDArray x3 = 150. / x2;
|
||||||
}
|
}
|
||||||
catch (std::exception& message) {
|
catch (std::exception& message) {
|
||||||
printf("%s\n", message.what());
|
// printf("%s\n", message.what());
|
||||||
ASSERT_TRUE(1);
|
ASSERT_TRUE(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1969,7 +1969,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) {
|
|||||||
x1.divRowVector(&x4, &x3);
|
x1.divRowVector(&x4, &x3);
|
||||||
}
|
}
|
||||||
catch (std::exception& message) {
|
catch (std::exception& message) {
|
||||||
printf("%s\n", message.what());
|
// printf("%s\n", message.what());
|
||||||
ASSERT_TRUE(1);
|
ASSERT_TRUE(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1977,7 +1977,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) {
|
|||||||
x1.applyBroadcast(nd4j::broadcast::FloorDiv, {1}, &x4, &x3);
|
x1.applyBroadcast(nd4j::broadcast::FloorDiv, {1}, &x4, &x3);
|
||||||
}
|
}
|
||||||
catch (std::exception& message) {
|
catch (std::exception& message) {
|
||||||
printf("%s\n", message.what());
|
// printf("%s\n", message.what());
|
||||||
ASSERT_TRUE(1);
|
ASSERT_TRUE(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1985,7 +1985,7 @@ TEST_F(MultiDataTypeTests, divide_bool_test1) {
|
|||||||
x1.applyTrueBroadcast(BROADCAST(FloorMod), &x2, &x3, true);
|
x1.applyTrueBroadcast(BROADCAST(FloorMod), &x2, &x3, true);
|
||||||
}
|
}
|
||||||
catch (std::exception& message) {
|
catch (std::exception& message) {
|
||||||
printf("%s\n", message.what());
|
// printf("%s\n", message.what());
|
||||||
ASSERT_TRUE(1);
|
ASSERT_TRUE(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1998,13 +1998,13 @@ TEST_F(MultiDataTypeTests, aaa) {
|
|||||||
z.permutei({1,0});
|
z.permutei({1,0});
|
||||||
|
|
||||||
nd4j::graph::RandomGenerator gen(119,5);
|
nd4j::graph::RandomGenerator gen(119,5);
|
||||||
std::vector<double> extraArguments = {1.5, 2.5};
|
ExtraArguments extras({1.5, 2.5});
|
||||||
|
|
||||||
NativeOpExecutioner::execRandom(nullptr, nd4j::random::UniformDistribution,
|
NativeOpExecutioner::execRandom(LaunchContext::defaultContext(), nd4j::random::UniformDistribution,
|
||||||
&gen,
|
&gen,
|
||||||
z.buffer(), z.getShapeInfo(), nullptr, nullptr,
|
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
extraArguments.data());
|
extras.argumentsAsT<double>());
|
||||||
z.printIndexedBuffer();
|
// z.printIndexedBuffer();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2016,7 +2016,7 @@ TEST_F(MultiDataTypeTests, assign_2)
|
|||||||
NDArray expected('c', {4}, {1,2,3,4}, nd4j::DataType::INT32);
|
NDArray expected('c', {4}, {1,2,3,4}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
y.assign(x);
|
y.assign(x);
|
||||||
y.printBuffer();
|
// y.printBuffer();
|
||||||
|
|
||||||
ASSERT_TRUE(expected.equalsTo(&y));
|
ASSERT_TRUE(expected.equalsTo(&y));
|
||||||
}
|
}
|
||||||
|
@ -99,14 +99,14 @@ TEST_F(NDArrayConstructorsTests, test_constructor_4) {
|
|||||||
TEST_F(NDArrayConstructorsTests, test_constructor_5) {
|
TEST_F(NDArrayConstructorsTests, test_constructor_5) {
|
||||||
auto x = NDArrayFactory::create<double>('c',{2, 2}, {1, 2, 3, 4});
|
auto x = NDArrayFactory::create<double>('c',{2, 2}, {1, 2, 3, 4});
|
||||||
|
|
||||||
ASSERT_FALSE(x.buffer() == nullptr);
|
ASSERT_TRUE(x.buffer() == nullptr);
|
||||||
ASSERT_FALSE(x.specialBuffer() == nullptr);
|
ASSERT_FALSE(x.specialBuffer() == nullptr);
|
||||||
|
|
||||||
ASSERT_FALSE(x.shapeInfo() == nullptr);
|
ASSERT_FALSE(x.shapeInfo() == nullptr);
|
||||||
ASSERT_FALSE(x.specialShapeInfo() == nullptr);
|
ASSERT_FALSE(x.specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
ASSERT_TRUE(x.isActualOnDeviceSide());
|
ASSERT_TRUE(x.isActualOnDeviceSide());
|
||||||
ASSERT_TRUE(x.isActualOnHostSide());
|
ASSERT_FALSE(x.isActualOnHostSide());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NDArrayConstructorsTests, test_constructor_6) {
|
TEST_F(NDArrayConstructorsTests, test_constructor_6) {
|
||||||
@ -139,14 +139,14 @@ TEST_F(NDArrayConstructorsTests, test_constructor_7) {
|
|||||||
TEST_F(NDArrayConstructorsTests, test_constructor_8) {
|
TEST_F(NDArrayConstructorsTests, test_constructor_8) {
|
||||||
auto x = NDArrayFactory::create_<double>('c',{2, 2}, {1, 2, 3, 4});
|
auto x = NDArrayFactory::create_<double>('c',{2, 2}, {1, 2, 3, 4});
|
||||||
|
|
||||||
ASSERT_FALSE(x->buffer() == nullptr);
|
ASSERT_TRUE(x->buffer() == nullptr);
|
||||||
ASSERT_FALSE(x->specialBuffer() == nullptr);
|
ASSERT_FALSE(x->specialBuffer() == nullptr);
|
||||||
|
|
||||||
ASSERT_FALSE(x->shapeInfo() == nullptr);
|
ASSERT_FALSE(x->shapeInfo() == nullptr);
|
||||||
ASSERT_FALSE(x->specialShapeInfo() == nullptr);
|
ASSERT_FALSE(x->specialShapeInfo() == nullptr);
|
||||||
|
|
||||||
ASSERT_TRUE(x->isActualOnDeviceSide());
|
ASSERT_TRUE(x->isActualOnDeviceSide());
|
||||||
ASSERT_TRUE(x->isActualOnHostSide());
|
ASSERT_FALSE(x->isActualOnHostSide());
|
||||||
|
|
||||||
delete x;
|
delete x;
|
||||||
}
|
}
|
||||||
@ -184,7 +184,7 @@ TEST_F(NDArrayConstructorsTests, test_linspace_1) {
|
|||||||
TEST_F(NDArrayConstructorsTests, test_constructor_10) {
|
TEST_F(NDArrayConstructorsTests, test_constructor_10) {
|
||||||
|
|
||||||
NDArray scalar1(nd4j::DataType::DOUBLE); // scalar1 = 0
|
NDArray scalar1(nd4j::DataType::DOUBLE); // scalar1 = 0
|
||||||
NDArray scalar2('c', {0}, {0});
|
NDArray scalar2('c', {}, {0});
|
||||||
|
|
||||||
ASSERT_TRUE(scalar1.isActualOnDeviceSide());
|
ASSERT_TRUE(scalar1.isActualOnDeviceSide());
|
||||||
ASSERT_TRUE(!scalar1.isActualOnHostSide());
|
ASSERT_TRUE(!scalar1.isActualOnHostSide());
|
||||||
|
@ -70,7 +70,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_Registration_1) {
|
|||||||
auto y = NDArrayFactory::create<int>('c', {5}, {5, 4, 3, 2, 1});
|
auto y = NDArrayFactory::create<int>('c', {5}, {5, 4, 3, 2, 1});
|
||||||
|
|
||||||
ASSERT_TRUE(x.isActualOnDeviceSide());
|
ASSERT_TRUE(x.isActualOnDeviceSide());
|
||||||
ASSERT_TRUE(x.isActualOnHostSide());
|
ASSERT_FALSE(x.isActualOnHostSide());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_Registration_2) {
|
TEST_F(NDArrayCudaBasicsTests, Test_Registration_2) {
|
||||||
@ -86,7 +86,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_Registration_3) {
|
|||||||
auto y = NDArrayFactory::create<int>('c', {5}, {5, 4, 3, 2, 1});
|
auto y = NDArrayFactory::create<int>('c', {5}, {5, 4, 3, 2, 1});
|
||||||
|
|
||||||
ASSERT_TRUE(x.isActualOnDeviceSide());
|
ASSERT_TRUE(x.isActualOnDeviceSide());
|
||||||
ASSERT_TRUE(x.isActualOnHostSide());
|
ASSERT_FALSE(x.isActualOnHostSide());
|
||||||
|
|
||||||
NDArray::registerSpecialUse({&x}, {&y});
|
NDArray::registerSpecialUse({&x}, {&y});
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_Registration_3) {
|
|||||||
ASSERT_FALSE(x.isActualOnHostSide());
|
ASSERT_FALSE(x.isActualOnHostSide());
|
||||||
|
|
||||||
ASSERT_TRUE(y.isActualOnDeviceSide());
|
ASSERT_TRUE(y.isActualOnDeviceSide());
|
||||||
ASSERT_TRUE(y.isActualOnHostSide());
|
ASSERT_FALSE(y.isActualOnHostSide());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_Registration_01) {
|
TEST_F(NDArrayCudaBasicsTests, Test_Registration_01) {
|
||||||
@ -102,7 +102,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_Registration_01) {
|
|||||||
auto y = NDArrayFactory::create_<int>('c', {5}, {5, 4, 3, 2, 1});
|
auto y = NDArrayFactory::create_<int>('c', {5}, {5, 4, 3, 2, 1});
|
||||||
|
|
||||||
ASSERT_TRUE(x->isActualOnDeviceSide());
|
ASSERT_TRUE(x->isActualOnDeviceSide());
|
||||||
ASSERT_TRUE(x->isActualOnHostSide());
|
ASSERT_FALSE(x->isActualOnHostSide());
|
||||||
delete x;
|
delete x;
|
||||||
delete y;
|
delete y;
|
||||||
}
|
}
|
||||||
@ -122,7 +122,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_Registration_03) {
|
|||||||
auto y = NDArrayFactory::create_<int>('c', {5}, {5, 4, 3, 2, 1});
|
auto y = NDArrayFactory::create_<int>('c', {5}, {5, 4, 3, 2, 1});
|
||||||
|
|
||||||
ASSERT_TRUE(x->isActualOnDeviceSide());
|
ASSERT_TRUE(x->isActualOnDeviceSide());
|
||||||
ASSERT_TRUE(x->isActualOnHostSide());
|
ASSERT_FALSE(x->isActualOnHostSide());
|
||||||
|
|
||||||
NDArray::registerSpecialUse({y}, {x});
|
NDArray::registerSpecialUse({y}, {x});
|
||||||
x->applyTransform(transform::Neg, y, nullptr);
|
x->applyTransform(transform::Neg, y, nullptr);
|
||||||
@ -142,7 +142,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_Cosine_1) {
|
|||||||
auto y = NDArrayFactory::create_<double>('c', {5}, {5, 4, 3, 2, 1});
|
auto y = NDArrayFactory::create_<double>('c', {5}, {5, 4, 3, 2, 1});
|
||||||
|
|
||||||
ASSERT_TRUE(x->isActualOnDeviceSide());
|
ASSERT_TRUE(x->isActualOnDeviceSide());
|
||||||
ASSERT_TRUE(x->isActualOnHostSide());
|
ASSERT_FALSE(x->isActualOnHostSide());
|
||||||
|
|
||||||
NDArray::registerSpecialUse({y}, {x});
|
NDArray::registerSpecialUse({y}, {x});
|
||||||
x->applyTransform(transform::Cosine, y, nullptr);
|
x->applyTransform(transform::Cosine, y, nullptr);
|
@@ -552,7 +552,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveNeg_2) {
     auto y = NDArrayFactory::create<double>('c', {5});

     ASSERT_TRUE(x.isActualOnDeviceSide());
-    ASSERT_TRUE(x.isActualOnHostSide());
+    ASSERT_FALSE(x.isActualOnHostSide());

     x.applyTransform(transform::Neg, &y, nullptr);
     //ASSERT_TRUE(x->isActualOnDeviceSide());
@@ -572,7 +572,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveSqrt_1) { // strict
     auto y = NDArrayFactory::create<double>('c', {5});
     auto exp = NDArrayFactory::create<double>({1.000000, 1.414214, 1.732051, 2.000000, 2.236068});
     ASSERT_TRUE(x.isActualOnDeviceSide());
-    ASSERT_TRUE(x.isActualOnHostSide());
+    ASSERT_FALSE(x.isActualOnHostSide());

     x.applyTransform(transform::Sqrt, &y, nullptr);
     //ASSERT_TRUE(x->isActualOnDeviceSide());
@@ -619,7 +619,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_1) { // strict
     auto exp = NDArrayFactory::create<double>('c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662});

     ASSERT_TRUE(x.isActualOnDeviceSide());
-    ASSERT_TRUE(x.isActualOnHostSide());
+    ASSERT_FALSE(x.isActualOnHostSide());

     x.applyTransform(transform::Cosine, &y, nullptr);
     //ASSERT_TRUE(x->isActualOnDeviceSide());
@@ -642,7 +642,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_2) {
     auto exp = NDArrayFactory::create<double>('c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662});

     ASSERT_TRUE(x.isActualOnDeviceSide());
-    ASSERT_TRUE(x.isActualOnHostSide());
+    ASSERT_FALSE(x.isActualOnHostSide());
     x.applyTransform(transform::Cosine, &y, nullptr);
     //ASSERT_TRUE(x->isActualOnDeviceSide());
     //ASSERT_FALSE(x->isActualOnHostSide());
@@ -671,7 +671,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_3) {
     auto exp = NDArrayFactory::create<double>({0.540302, -0.416147, -0.989992, -0.653644, 0.283662});

     ASSERT_TRUE(x.isActualOnDeviceSide());
-    ASSERT_TRUE(x.isActualOnHostSide());
+    ASSERT_FALSE(x.isActualOnHostSide());
     x.applyTransform(transform::Cosine, &y, nullptr);
     //ASSERT_TRUE(x->isActualOnDeviceSide());
     //ASSERT_FALSE(x->isActualOnHostSide());
@@ -1263,8 +1263,8 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_3) {
     NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
     NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);

-    NDArray exp1('c', {0}, {-204}, nd4j::DataType::FLOAT32);
-    NDArray exp2('c', {0}, {31.5}, nd4j::DataType::DOUBLE);
+    NDArray exp1('c', {}, {-204}, nd4j::DataType::FLOAT32);
+    NDArray exp2('c', {}, {31.5}, nd4j::DataType::DOUBLE);


     auto z = x1.applyReduce3(reduce3::Dot, &x2);
@@ -1340,15 +1340,15 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test1) {

     NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE);

-    NDArray scalar('c', {0}, {100}, nd4j::DataType::INT64);
+    NDArray scalar('c', {}, {100}, nd4j::DataType::INT64);
     NDArray vec1('c', {2}, {100,100}, nd4j::DataType::INT64);
     NDArray vec2('c', {3}, {100,100,100}, nd4j::DataType::INT64);

-    NDArray exp1('c', {0}, {1}, nd4j::DataType::INT64);
+    NDArray exp1('c', {}, {1}, nd4j::DataType::INT64);
     NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64);
     NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64);

-    NDArray exp4('c', {0}, {2}, nd4j::DataType::INT64);
+    NDArray exp4('c', {}, {2}, nd4j::DataType::INT64);
     NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64);
     NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64);

@@ -1379,11 +1379,11 @@ TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test2) {

     NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, nd4j::DataType::DOUBLE);

-    NDArray exp1('c', {0}, {1}, nd4j::DataType::INT64);
+    NDArray exp1('c', {}, {1}, nd4j::DataType::INT64);
     NDArray exp2('c', {2}, {1,1}, nd4j::DataType::INT64);
     NDArray exp3('c', {3}, {1,0,0}, nd4j::DataType::INT64);

-    NDArray exp4('c', {0}, {2}, nd4j::DataType::INT64);
+    NDArray exp4('c', {}, {2}, nd4j::DataType::INT64);
     NDArray exp5('c', {2}, {1,1}, nd4j::DataType::INT64);
     NDArray exp6('c', {3}, {1,0,0}, nd4j::DataType::INT64);

@@ -1419,13 +1419,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) {

     NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::INT32);

-    NDArray z1('c', {0}, {100}, nd4j::DataType::DOUBLE);
+    NDArray z1('c', {}, {100}, nd4j::DataType::DOUBLE);
     NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
     NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::DOUBLE);
     NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
     NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);

-    NDArray exp1('c', {0}, {2.166667}, nd4j::DataType::DOUBLE);
+    NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE);
     NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::FLOAT32);
     NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
     NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32);
@@ -1457,7 +1457,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test2) {

     NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, nd4j::DataType::DOUBLE);

-    NDArray exp1('c', {0}, {2.166667}, nd4j::DataType::DOUBLE);
+    NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE);
     NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::DOUBLE);
     NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
     NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::DOUBLE);
@@ -1535,13 +1535,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {

     NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::FLOAT32);

-    NDArray z1('c', {0}, {100}, nd4j::DataType::FLOAT32);
+    NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32);
     NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
     NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::FLOAT32);
     NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
     NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);

-    NDArray exp1('c', {0}, {26.5}, nd4j::DataType::FLOAT32);
+    NDArray exp1('c', {}, {26.5}, nd4j::DataType::FLOAT32);
     NDArray exp2('c', {2,2}, {9.5,12,3,2}, nd4j::DataType::FLOAT32);
     NDArray exp3('c', {3}, {19,4,3.5}, nd4j::DataType::FLOAT32);
     NDArray exp4('c', {3,2}, {9,10,2,2,1.5,2}, nd4j::DataType::FLOAT32);
@@ -1573,7 +1573,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test2) {

     NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::INT64);

-    NDArray exp1('c', {0}, {26}, nd4j::DataType::INT64);
+    NDArray exp1('c', {}, {26}, nd4j::DataType::INT64);
     NDArray exp2('c', {2,2}, {9,12,3,2}, nd4j::DataType::INT64);
     NDArray exp3('c', {3}, {18,4,4}, nd4j::DataType::INT64);
     NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, nd4j::DataType::INT64);
@@ -1605,13 +1605,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) {

     NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE);

-    NDArray z1('c', {0}, {100}, nd4j::DataType::BOOL);
+    NDArray z1('c', {}, {100}, nd4j::DataType::BOOL);
     NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::BOOL);
     NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::BOOL);
     NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::BOOL);
     NDArray z5('c', {2}, {100,100}, nd4j::DataType::BOOL);

-    NDArray exp1('c', {0}, {1}, nd4j::DataType::BOOL);
+    NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL);
     NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL);
     NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL);
     NDArray exp4('c', {3,2}, {1,1,1,0,1,1}, nd4j::DataType::BOOL);
@@ -1643,7 +1643,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) {

     NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::INT32);

-    NDArray exp1('c', {0}, {1}, nd4j::DataType::BOOL);
+    NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL);
     NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL);
     NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL);
     NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, nd4j::DataType::BOOL);
@@ -1675,13 +1675,13 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) {

     NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::FLOAT32);

-    NDArray z1('c', {0}, {100}, nd4j::DataType::INT64);
+    NDArray z1('c', {}, {100}, nd4j::DataType::INT64);
     NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64);
     NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::INT64);
     NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::INT64);
     NDArray z5('c', {2}, {100,100}, nd4j::DataType::INT64);

-    NDArray exp1('c', {0}, {2}, nd4j::DataType::INT64);
+    NDArray exp1('c', {}, {2}, nd4j::DataType::INT64);
     NDArray exp2('c', {2,2}, {0,1,0,1}, nd4j::DataType::INT64);
     NDArray exp3('c', {3}, {1,1,0}, nd4j::DataType::INT64);
     NDArray exp4('c', {3,2}, {0,1,0,1,0,0}, nd4j::DataType::INT64);
@@ -1713,7 +1713,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test2) {

     NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::INT32);

-    NDArray exp1('c', {0}, {4}, nd4j::DataType::INT64);
+    NDArray exp1('c', {}, {4}, nd4j::DataType::INT64);
     NDArray exp2('c', {2,2}, {1,1,0,2}, nd4j::DataType::INT64);
     NDArray exp3('c', {3}, {2,2,0}, nd4j::DataType::INT64);
     NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, nd4j::DataType::INT64);
@@ -1792,8 +1792,8 @@ TEST_F(NDArrayCudaBasicsTests, TestFloat16_1) {
 }

 TEST_F(NDArrayCudaBasicsTests, TestFloat16_2) {
-    auto x = NDArrayFactory::create<float16>('c', {9}, {1,2,3,4,5,7,8,9});
-    auto y = NDArrayFactory::create<float16>('c', {9}, {1,2,3,4,5,7,8,9});
+    auto x = NDArrayFactory::create<float16>('c', {9}, {1,2,3,4,5,6,7,8,9});
+    auto y = NDArrayFactory::create<float16>('c', {9}, {1,2,3,4,5,6,7,8,9});
     ASSERT_TRUE(x.equalsTo(y));
     //for (int e = 0; e < x.lengthOf(); e++)
     //    ASSERT_NEAR(x.e<float16>(e), y.e<float16>(e), 1.e-5f);
@@ -1812,14 +1812,14 @@ TEST_F(NDArrayCudaBasicsTests, TestFloat_4) {
 }

 TEST_F(NDArrayCudaBasicsTests, TestFloat_5) {
-    auto x = NDArrayFactory::create<float>('c', {3,3}, {1,2,3,4,5,7,8,9});
-    auto y = NDArrayFactory::create<float>('c', {3,3}, {2,4,5,5,6,7,8,9});
+    auto x = NDArrayFactory::create<float>('c', {3,3}, {1,2,3,4,5,6,7,8,9});
+    auto y = NDArrayFactory::create<float>('c', {3,3}, {2,4,5,5,6,7,8,9, 10});
     ASSERT_FALSE(x.equalsTo(&y));
 }

 TEST_F(NDArrayCudaBasicsTests, TestFloat_6) {
-    auto x = NDArrayFactory::create<float>('f', {3,3}, {1,2,3,4,5,7,8,9});
-    auto y = NDArrayFactory::create<float>('f', {3,3}, {2,4,5,5,6,7,8,9});
+    auto x = NDArrayFactory::create<float>('f', {3,3}, {1,2,3,4,5,6,7,8,9});
+    auto y = NDArrayFactory::create<float>('f', {3,3}, {2,4,5,5,6,7,8,9,10});
     ASSERT_FALSE(x.equalsTo(&y));
 }

@@ -2257,9 +2257,12 @@ TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) {
     // cudaStreamSynchronize(*stream);

 TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {

     auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
+    x.syncToHost();
     auto z = NDArrayFactory::create<float>('c', {5, 8});
-    auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
+    z.syncToHost();

     std::vector<void*> buffers(4);
     std::vector<Nd4jLong*> shapes(4);
     std::vector<Nd4jLong*> hostShapes(4);
@@ -2270,193 +2273,166 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
         hostShapes[i] = x.shapeInfo();
     }
     Nd4jPointer extra[2];
-    extra[1] = *stream;
+    extra[1] = x.getContext()->getCudaStream();
     ::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
-    z.syncToHost();
-    z.printIndexedBuffer("Concat result");
-    z.printBuffer("C Concat result linear");

 }

 TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {

     auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
     auto z = NDArrayFactory::create<float>('f', {5, 8});
-    auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
     std::vector<void*> buffers(4);
     std::vector<Nd4jLong*> shapes(4);
     std::vector<Nd4jLong*> hostShapes(4);

+    x.syncToHost();
+    z.syncToHost();

     for (size_t i = 0; i < buffers.size(); i++) {
         buffers[i] = x.specialBuffer();
         shapes[i] = x.specialShapeInfo();
         hostShapes[i] = x.shapeInfo();
     }
-    Nd4jPointer extra[2];
-    extra[1] = *stream;
-    ::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
-    z.syncToHost();
-    z.printIndexedBuffer("Concat result");
-    z.printBuffer("F Concat result linear");

+    Nd4jPointer extra[2];
+    extra[1] = x.getContext()->getCudaStream();

+    ::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 }

 TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {

     auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
     auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
     auto z = NDArrayFactory::create<float>('f', {3, 3});
-    auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);

     std::vector<void*> buffers(2);
     std::vector<Nd4jLong*> shapes(2);
     std::vector<Nd4jLong*> hostShapes(2);

-    //for (size_t i = 0; i < buffers.size(); i++) {
-    buffers[0] = x.specialBuffer();
-    shapes[0] = x.specialShapeInfo();
-    hostShapes[0] = x.shapeInfo();
-    buffers[1] = y.specialBuffer();
-    shapes[1] = y.specialShapeInfo();
-    hostShapes[1] = y.shapeInfo();
-    //}
-    Nd4jPointer extra[2];
-    extra[1] = *stream;
-    ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
+    x.syncToHost();
+    y.syncToHost();
     z.syncToHost();
-    z.printIndexedBuffer("Concat result");
-    z.printBuffer("F Concat result linear");

+    buffers[0] = x.specialBuffer(); shapes[0] = x.specialShapeInfo(); hostShapes[0] = x.shapeInfo();
+    buffers[1] = y.specialBuffer(); shapes[1] = y.specialShapeInfo(); hostShapes[1] = y.shapeInfo();

+    Nd4jPointer extra[2];
+    extra[1] = x.getContext()->getCudaStream();

+    ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 }

 TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {

     auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
     auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
     auto z = NDArrayFactory::create<float>('c', {3, 3});
-    auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
+    x.syncToHost();
+    y.syncToHost();
+    z.syncToHost();

     std::vector<void*> buffers(2);
     std::vector<Nd4jLong*> shapes(2);
     std::vector<Nd4jLong*> hostShapes(2);

-    //for (size_t i = 0; i < buffers.size(); i++) {
-    buffers[0] = x.specialBuffer();
-    shapes[0] = x.specialShapeInfo();
-    hostShapes[0] = x.shapeInfo();
-    buffers[1] = y.specialBuffer();
-    shapes[1] = y.specialShapeInfo();
-    hostShapes[1] = y.shapeInfo();
-    //}
-    Nd4jPointer extra[2];
-    extra[1] = *stream;
-    ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
-    z.syncToHost();
-    z.printIndexedBuffer("Concat result");
-    z.printBuffer("C Concat result linear");
+    buffers[0] = x.specialBuffer(); shapes[0] = x.specialShapeInfo(); hostShapes[0] = x.shapeInfo();
+    buffers[1] = y.specialBuffer(); shapes[1] = y.specialShapeInfo(); hostShapes[1] = y.shapeInfo();

+    Nd4jPointer extra[2];
+    extra[1] = x.getContext()->getCudaStream();

+    ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 }

 TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {

     auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6});
     auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12});

     auto z = NDArrayFactory::create<float>('c', {2, 2, 3});
     auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
     std::vector<void*> buffers(2);
     std::vector<Nd4jLong*> shapes(2);
     std::vector<Nd4jLong*> hostShapes(2);

-    //for (size_t i = 0; i < buffers.size(); i++) {
-    buffers[0] = x.specialBuffer();
-    shapes[0] = x.specialShapeInfo();
-    hostShapes[0] = x.shapeInfo();
-    buffers[1] = y.specialBuffer();
-    shapes[1] = y.specialShapeInfo();
-    hostShapes[1] = y.shapeInfo();
-    //}
-    Nd4jPointer extra[2];
-    extra[1] = *stream;
-    ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
-    z.syncToHost();
-    z.printIndexedBuffer("Concat result");
-    z.printBuffer("C Concat result linear");
+    buffers[0] = x.specialBuffer(); shapes[0] = x.specialShapeInfo(); hostShapes[0] = x.shapeInfo();
+    buffers[1] = y.specialBuffer(); shapes[1] = y.specialShapeInfo(); hostShapes[1] = y.shapeInfo();

+    Nd4jPointer extra[2];
+    extra[1] = x.getContext()->getCudaStream();

+    ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 }

 TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {

     auto x1 = NDArrayFactory::create<float>('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12});
     auto x2 = NDArrayFactory::create<float>('c', {1,2,3}, {13,14,15,16,17, 18});
     auto x3 = NDArrayFactory::create<float>('c', {1,2,3}, {19,20,21,22,23, 24});

+    x1.syncToHost();
+    x2.syncToHost();
+    x3.syncToHost();

     auto z = NDArrayFactory::create<float>('c', {4, 2, 3});
-    auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
     std::vector<void*> buffers(3);
     std::vector<Nd4jLong*> shapes(3);
     std::vector<Nd4jLong*> hostShapes(3);

-    //for (size_t i = 0; i < buffers.size(); i++) {
-    buffers[0] = x1.specialBuffer();
-    shapes[0] = x1.specialShapeInfo();
-    hostShapes[0] = x1.shapeInfo();
-    buffers[1] = x2.specialBuffer();
-    shapes[1] = x2.specialShapeInfo();
-    hostShapes[1] = x2.shapeInfo();
-    buffers[2] = x3.specialBuffer();
-    shapes[2] = x3.specialShapeInfo();
-    hostShapes[2] = x3.shapeInfo();
-    //}
-    printf("The third array is %p\n", buffers[2]);
-    Nd4jPointer extra[2];
-    extra[1] = *stream;
-    ::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
-    z.syncToHost();
-    z.printIndexedBuffer("Concat result");
-    z.printBuffer("C Concat3D result linear");
+    buffers[0] = x1.specialBuffer(); shapes[0] = x1.specialShapeInfo(); hostShapes[0] = x1.shapeInfo();
+    buffers[1] = x2.specialBuffer(); shapes[1] = x2.specialShapeInfo(); hostShapes[1] = x2.shapeInfo();
+    buffers[2] = x3.specialBuffer(); shapes[2] = x3.specialShapeInfo(); hostShapes[2] = x3.shapeInfo();

+    Nd4jPointer extra[2];
+    extra[1] = x1.getContext()->getCudaStream();

+    ::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 }
||||||
|
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
    auto x1 = NDArrayFactory::create<float>(1);
    auto x2 = NDArrayFactory::create<float>(2);
    auto x3 = NDArrayFactory::create<float>(3);

    auto z = NDArrayFactory::create<float>('c', {3}, {1,2,3});

    x1.syncToHost();
    x2.syncToHost();
    x3.syncToHost();

    std::vector<void*> buffers(3);
    std::vector<Nd4jLong*> shapes(3);
    std::vector<Nd4jLong*> hostShapes(3);

    buffers[0] = x1.specialBuffer(); shapes[0] = x1.specialShapeInfo(); hostShapes[0] = x1.shapeInfo();
    buffers[1] = x2.specialBuffer(); shapes[1] = x2.specialShapeInfo(); hostShapes[1] = x2.shapeInfo();
    buffers[2] = x3.specialBuffer(); shapes[2] = x3.specialShapeInfo(); hostShapes[2] = x3.shapeInfo();

    Nd4jPointer extra[2];
    extra[1] = x1.getContext()->getCudaStream();

    ::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
}
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
    auto totalCount = 1000;
    auto width = 300;
    std::vector<NDArray> lx(totalCount);
    for (int i = 0; i < totalCount; i++) {
        lx[i] = NDArrayFactory::create<float>('c', {1, width});
        lx[i].assign(i);
        lx[i].syncToHost();
    }

    auto z = NDArrayFactory::create<float>('c', {totalCount, width});

    std::vector<void*> buffers(totalCount);
    std::vector<Nd4jLong*> shapes(totalCount);
    std::vector<Nd4jLong*> hostShapes(totalCount);
@@ -2467,16 +2443,10 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
        hostShapes[i] = lx[i].shapeInfo();
    }

    Nd4jPointer extra[2];
    extra[1] = nd4j::LaunchContext::defaultContext()->getCudaStream();

    ::concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
}
TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
@@ -2489,9 +2459,8 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
    }
    auto z = NDArrayFactory::create<float>('c', {total, 10, 10});

    Nd4jPointer extra[2];
    extra[1] = input.getContext()->getCudaStream();

    std::vector<void*> buffers(total);
    std::vector<Nd4jLong*> shapes(total);
@@ -2520,17 +2489,20 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
    }
TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
    auto input = NDArrayFactory::create<float>('c', {1, 10, 10});

    std::vector<NDArray> arrays; // = {NDArrayFactory::create<float>('c', {1, 10, 10}), NDArrayFactory::create<float>('c', {1, 10, 10}), NDArrayFactory::create<float>('c', {1, 10, 10}), NDArrayFactory::create<float>('c', {1, 10, 10}), NDArrayFactory::create<float>('c', {1, 10, 10})};
    for (int e = 0; e < 10; e++) {
        input.assign(e);
        arrays.emplace_back(input);
        arrays[e].syncToHost();
    }

    auto z = NDArrayFactory::create<float>('c', {10, 10, 10});

    Nd4jPointer extra[2];
    extra[1] = input.getContext()->getCudaStream();

    std::vector<void*> buffers(10);
    std::vector<Nd4jLong*> shapes(10);
@@ -2541,12 +2513,12 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
        shapes[i] = arrays[i].specialShapeInfo();
        hostShapes[i] = arrays[i].shapeInfo();
    }

    std::vector<int> dimsToExclude({1,2});

    ::concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);

    auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), dimsToExclude);
    //std::vector<void*> arraysData(arrays.size());
    Nd4jPointer* arraysData;
@@ -2561,7 +2533,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
    }
    ::tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
    // auto result = op.execute({&z}, {}, {1, 2});
    //ASSERT_EQ(10, result->size());
    err = cudaFree(arraysData);
    if (err != 0) {
@@ -2569,11 +2541,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
        ASSERT_TRUE(false);
    }
    // ASSERT_TRUE(tads->at(e)->equalsTo(result->at(e)));

    // delete result;