* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* bunch of tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* hamming distance nullification

Signed-off-by: raver119 <raver119@gmail.com>

* Add output array value assignment for testing/debugging

Signed-off-by: Alex Black <blacka101@gmail.com>

* don't assign empty arrays

Signed-off-by: raver119 <raver119@gmail.com>

* conv2d/conv3d/depthwise2d nullified

Signed-off-by: raver119 <raver119@gmail.com>

* conv2d/conv3d/depthwise2d nullified

Signed-off-by: raver119 <raver119@gmail.com>

* conv2d/conv3d/depthwise2d nullified

Signed-off-by: raver119 <raver119@gmail.com>

* few more fixes

Signed-off-by: raver119 <raver119@gmail.com>

* im2col

Signed-off-by: raver119 <raver119@gmail.com>

* pooling?

Signed-off-by: raver119 <raver119@gmail.com>

* more nullified

Signed-off-by: raver119 <raver119@gmail.com>

* ismax nullified

Signed-off-by: raver119 <raver119@gmail.com>

* rollback ismax nullification

Signed-off-by: raver119 <raver119@gmail.com>

* synchronized cublas handle use on per-device basis

Signed-off-by: raver119 <raver119@gmail.com>

* hiding method from jcpp

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of test assigns in DeclarableOp

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of assigns

Signed-off-by: raver119 <raver119@gmail.com>

* proper deviceId is back

Signed-off-by: raver119 <raver119@gmail.com>

* include fixed

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
Branch: master
Author: raver119, 2020-03-20 08:49:28 +03:00, committed by GitHub
Parent: 30a28fae45
Commit: 7a2ac800dd
57 changed files with 229 additions and 152 deletions

View File

@@ -277,13 +277,13 @@ namespace sd {
     /**
      * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
      */
-    NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+    NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true);

     /**
      * constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
      * set dtype as array type
      */
-    NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
+    NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true);

     /**
      * this constructor creates new array using shape information contained in vector argument

View File

@@ -143,7 +143,7 @@ NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &sh
 ////////////////////////////////////////////////////////////////////////
 // creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros
-NDArray::NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context) {
+NDArray::NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context, const bool nullify) {

     if (shapeInfo == nullptr)
         throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo");

@@ -161,7 +161,9 @@ NDArray::NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyS
     if (!isEmpty()) {
         _buffer = std::make_shared<DataBuffer>(lengthOf() * sizeOfT(), dtype, getContext()->getWorkspace());
-        _buffer->setToZeroBuffers();
+
+        if (nullify)
+            _buffer->setToZeroBuffers();
     }
 }

@@ -213,7 +215,7 @@ NDArray::NDArray(sd::LaunchContext * context) {
 ////////////////////////////////////////////////////////////////////////
 // creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros, set dtype as array type
-NDArray::NDArray(Nd4jLong* shapeInfo, const bool copyStrides, sd::LaunchContext * context):
+NDArray::NDArray(Nd4jLong* shapeInfo, const bool copyStrides, sd::LaunchContext * context, const bool nullify):
         NDArray(shapeInfo, ArrayOptions::dataType(shapeInfo), copyStrides, context) {
 }

@@ -3339,9 +3341,6 @@ void NDArray::nullify() {
     if (isEmpty())
         return;

-    if (isS())
-        throw std::runtime_error("NDArray::nullify: can't nullify string array");
-
     if (isView() || ews() != 1)
         assign(0);
     else
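A minimal sketch of what the new constructor flag changes (the include path and shape arguments here are assumptions for illustration, not part of this diff): the default nullify = true keeps the old behaviour of zeroing the buffer, while nullify = false skips the memset for arrays the caller is guaranteed to overwrite.

```cpp
#include <array/NDArray.h>   // assumed header location

void sketch(Nd4jLong* shapeInfo, sd::LaunchContext* ctx) {
    // unchanged default: buffer is allocated and zero-initialized
    sd::NDArray zeroed(shapeInfo, sd::DataType::FLOAT32, /*copyStrides=*/false, ctx);

    // new opt-out: allocation only, no memset; every element must be written by the caller
    sd::NDArray raw(shapeInfo, sd::DataType::FLOAT32, /*copyStrides=*/false, ctx, /*nullify=*/false);
}
```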

View File

@@ -54,6 +54,8 @@ class ND4J_EXPORT LaunchContext {
     static std::vector<std::shared_ptr<LaunchContext>> _contexts;
     static std::mutex _mutex;

+    static MAP_IMPL<int, std::mutex*> _deviceMutexes;
+
     // used for MKLDNN
     void *_engine = nullptr;

@@ -93,7 +95,6 @@ class ND4J_EXPORT LaunchContext {
     void setCudaSpecialStream(cudaStream_t* cudaStream);
     void setCublasHandle(void *handle);
-
 #endif // JCPP
 #endif // CUDA

@@ -111,6 +112,12 @@ class ND4J_EXPORT LaunchContext {
     void setDeviceID(int deviceID) { _deviceID = deviceID; }
     sd::ErrorReference* errorReference();

+#ifndef __JAVACPP_HACK__
+    // this method returns mutex shared between all threads that use the same device
+    static std::mutex* deviceMutex();
+#endif
+
     static bool isInitialized();
     static void releaseBuffers();

View File

@@ -19,6 +19,7 @@
 //

 #include <execution/LaunchContext.h>
+#include <execution/AffinityManager.h>
 #include <helpers/logger.h>
 #include <exceptions/cuda_exception.h>
 #include <thread>

@@ -42,6 +43,7 @@ namespace sd {
     }

     std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
+    MAP_IMPL<int, std::mutex*> LaunchContext::_deviceMutexes;

 ////////////////////////////////////////////////////////////////////////
     LaunchContext::LaunchContext() {

@@ -49,6 +51,8 @@ namespace sd {
         _workspace = nullptr;
         _deviceID = 0;

+        _deviceMutexes[_deviceID] = new std::mutex();
+
 #ifdef HAVE_MKLDNN
         _engine = new dnnl::engine(dnnl::engine::kind::cpu, 0);
 #endif

@@ -68,6 +72,11 @@ namespace sd {
         return LaunchContext::_contexts[0].get();
     }

+    std::mutex* LaunchContext::deviceMutex() {
+        auto deviceId = AffinityManager::currentDeviceId();
+        return _deviceMutexes[deviceId];
+    }
+
     void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
         //
     }

View File

@@ -31,6 +31,7 @@ namespace sd {

 std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
 std::mutex LaunchContext::_mutex;
+MAP_IMPL<int, std::mutex*> LaunchContext::_deviceMutexes;

 ////////////////////////////////////////////////////////////////////////
 LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {

@@ -44,6 +45,11 @@ LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCuda
     _isAllocated = false;
 }

+std::mutex* LaunchContext::deviceMutex() {
+    auto deviceId = AffinityManager::currentDeviceId();
+    return _deviceMutexes[deviceId];
+}
+
 LaunchContext::~LaunchContext() {
     if (_isAllocated) {

@@ -85,6 +91,8 @@ LaunchContext::LaunchContext() {
     _contexts.resize(numDevices);

     for (int e = 0; e < numDevices; e++) {
+        _deviceMutexes[e] = new std::mutex();
+
         AffinityManager::setCurrentNativeDevice(e);

         LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();

View File

@@ -252,6 +252,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
     const bool typeIntFloat = AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6;
     const bool typeHalfFloat = AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6;

+    std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
+
     auto handle = reinterpret_cast<cublasHandle_t *>(A->getContext()->getCublasHandle());
     auto stream = A->getContext()->getCudaStream();

@@ -394,6 +396,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
     const bool typeDouble = AXY && aType == DataType::DOUBLE;
     const bool typeFloat = AXY && aType == DataType::FLOAT32;

+    std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
+
     auto handle = reinterpret_cast<cublasHandle_t *>(A->getContext()->getCublasHandle());
     auto stream = A->getContext()->getCudaStream();
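Both GEMM paths above now take the same per-device mutex before touching the shared cuBLAS handle, so host threads bound to one device serialize their handle configuration and launch. A condensed sketch of the pattern, assuming this usage (the function name is hypothetical and the cuBLAS calls are left as a placeholder since they are not shown in this hunk):

```cpp
// Sketch of the per-device serialization pattern used in MmulHelper (illustrative only).
void gemmOnSharedHandle(const sd::NDArray* A) {
    // one mutex per CUDA device, shared by every LaunchContext bound to that device
    std::lock_guard<std::mutex> lock(*sd::LaunchContext::deviceMutex());

    auto handle = reinterpret_cast<cublasHandle_t*>(A->getContext()->getCublasHandle());
    auto stream = A->getContext()->getCudaStream();

    // ... cublasSetStream(...) and the cublas*Gemm* call would go here,
    //     all performed while the per-device lock is held ...
}   // lock released here; other threads on this device may now use the handle
```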

View File

@@ -106,6 +106,7 @@ namespace sd {
         void storeResult(Context &block, int outputNumber, NDArray& array);
         void storeResult(Context &block, int outputNumber, NDArray* array);
         sd::NDArray* getZ(Context& block, int inputId = 0);
+        sd::NDArray* getNullifiedZ(Context& block, int inputId = 0);

         /**
          * This method pre-allocates NDArrays for Op output, in case they are not available at op execution time

View File

@@ -77,7 +77,15 @@ namespace sd {
          * @param inputId
          * @return
          */
-        sd::NDArray *getZ(graph::Context &ctx, int inputId);
+        sd::NDArray* getZ(graph::Context &ctx, int inputId);
+
+        /**
+         * Helper method, needed for compatibility with DeclarableOp macros
+         * @param ctx
+         * @param inputId
+         * @return
+         */
+        sd::NDArray* getNullifiedZ(graph::Context &ctx, int inputId);
     };
   }
 }
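The op hunks below switch many outputs from OUTPUT_VARIABLE(i) to OUTPUT_NULLIFIED(i). The macro's definition is not part of the hunks shown here, but given the getNullifiedZ declarations above it presumably resolves to the new helper, which hands back the same output array guaranteed to be zeroed before the op writes into it. A minimal sketch of the intended difference, with a hypothetical op name used purely for illustration:

```cpp
// Hypothetical illustration only; "my_op" is not part of this commit.
// OUTPUT_VARIABLE(0):  returns the pre-allocated output as-is (it may hold garbage
//                      once output buffers are created with nullify = false).
// OUTPUT_NULLIFIED(0): presumably calls getNullifiedZ(block, 0), so the array is
//                      zeroed before the op fills it.
CUSTOM_OP_IMPL(my_op, 1, 1, false, 0, 0) {
    auto x = INPUT_VARIABLE(0);
    auto z = OUTPUT_NULLIFIED(0);   // safe to accumulate into z, it starts at zero

    // ... kernel that adds partial results into z ...
    return Status::OK();
}
```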

View File

@@ -30,7 +30,7 @@ namespace sd {
     CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) {
         auto x = INPUT_VARIABLE(0);
         auto y = INPUT_VARIABLE(1);
-        auto output = OUTPUT_VARIABLE(0);
+        auto output = OUTPUT_NULLIFIED(0);

         REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length");
         REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type");

View File

@@ -32,7 +32,7 @@ namespace sd {
         auto values = INPUT_VARIABLE(2);
         NDArray *def = nullptr;

-        auto output = OUTPUT_VARIABLE(0);
+        auto output = OUTPUT_NULLIFIED(0);

         if (block.width() > 3)
             def = INPUT_VARIABLE(3);

View File

@@ -30,7 +30,7 @@ namespace sd {
         auto input = INPUT_VARIABLE(0);
         auto delim = INPUT_VARIABLE(1);

-        auto indices = OUTPUT_VARIABLE(0);
+        auto indices = OUTPUT_NULLIFIED(0);
         auto values = OUTPUT_VARIABLE(1);

         auto d = delim->e<std::string>(0);

View File

@@ -28,7 +28,7 @@ namespace sd {
     namespace ops {
         CUSTOM_OP_IMPL(col2im, 1, 1, false, 0, 9) {
             auto x = INPUT_VARIABLE(0);
-            auto z = OUTPUT_VARIABLE(0);
+            auto z = OUTPUT_NULLIFIED(0);

             REQUIRE_TRUE(x->rankOf() == 6, 0, "col2im input should be 6D, but got %i instead", x->rankOf());
             REQUIRE_TRUE(z->rankOf() == 4, 0, "col2im output should be 4D, but got %i instead", z->rankOf());

@@ -45,8 +45,6 @@ namespace sd {
             LaunchContext* ctx = block.launchContext();
             helpers::col2im(*ctx, *x, *z, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX);

-            STORE_RESULT(*z);
-
             return ND4J_STATUS_OK;
         }

         DECLARE_SHAPE_FN(col2im) {

View File

@@ -37,7 +37,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
     auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

-    auto output = OUTPUT_VARIABLE(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW)
+    auto output = OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW)

     int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
     int sW = INT_ARG(1); // strides width

@@ -167,9 +167,9 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next

-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon
-    auto gradW = OUTPUT_VARIABLE(1); // [kW, iC, oC] always
-    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon
+    auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC] always
+    auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]

     int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
     int sW = INT_ARG(1); // strides width

View File

@@ -40,7 +40,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
     auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

-    auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
+    auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

     int sH = INT_ARG(2); // strides height
     int sW = INT_ARG(3); // strides width

@@ -161,9 +161,9 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-    auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always
-    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always
+    auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]

     int kH = INT_ARG(0); // filter(kernel) height
     int kW = INT_ARG(1); // filter(kernel) width

@@ -267,7 +267,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
     auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
     auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

     int kH = INT_ARG(0); // filter(kernel) height
     int kW = INT_ARG(1); // filter(kernel) width

View File

@@ -38,7 +38,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
     auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

-    auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
+    auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

     REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
     REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());

View File

@@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
     auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
     auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)

-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width

View File

@@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
     auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC

-    auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
+    auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)

     REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
     REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());

@@ -152,9 +152,9 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next

-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
-    auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always
-    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
+    auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always
+    auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]

     REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
     REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());

View File

@@ -30,8 +30,7 @@ namespace sd {
     namespace ops {
         CUSTOM_OP_IMPL(im2col, 1, 1, false, 0, 9) {
             auto x = INPUT_VARIABLE(0);
-            auto z = OUTPUT_VARIABLE(0);
-
+            auto z = OUTPUT_NULLIFIED(0);

             REQUIRE_TRUE(x->rankOf() == 4, 0, "im2col input should be 4D, but got %i instead", x->rankOf());
             REQUIRE_TRUE(z->rankOf() == 6, 0, "im2col output should be 6D, but got %i instead", z->rankOf());

@@ -53,8 +52,6 @@ namespace sd {
             LaunchContext* ctx = block.launchContext();
             sd::ops::helpers::im2col(*ctx, *x, *z, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, NDArrayFactory::create(zeroPadVal, block.launchContext()));

-            STORE_RESULT(*z);
-
             return Status::OK();
         }

@@ -107,7 +104,7 @@ namespace sd {
         CUSTOM_OP_IMPL(im2col_bp, 2, 1, false, 0, 9) {
             auto input = INPUT_VARIABLE(0);
             auto gradAtOutput = INPUT_VARIABLE(1);
-            auto z = OUTPUT_VARIABLE(0);
+            auto z = OUTPUT_NULLIFIED(0);

             REQUIRE_TRUE(input->rankOf() == 4, 0, "im2col_bp input should be 4D, but got %i instead", input->rankOf());
             REQUIRE_TRUE(gradAtOutput->rankOf() == 6, 0, "im2col_bp gradient at output (input idx 1) should be 6D, but got %i instead", gradAtOutput->rankOf());

View File

@@ -37,7 +37,7 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
     NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
     NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC

-    NDArray *output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
+    NDArray *output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)

     if(block.width() == 3) {
         if((INPUT_VARIABLE(2))->rankOf() == 4)

@@ -199,26 +199,26 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
     NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
     NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr

-    NDArray *gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-    NDArray *gradWD = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always
+    NDArray *gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always
     NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC] always
     NDArray *gradB = nullptr; // [oC]

     if(block.width() == 4) {
         if((INPUT_VARIABLE(3))->rankOf() == 4) {
             weightsPoint = INPUT_VARIABLE(3);
-            gradWP = OUTPUT_VARIABLE(2);
+            gradWP = OUTPUT_NULLIFIED(2);
         }
         else {
             bias = INPUT_VARIABLE(3);
-            gradB = OUTPUT_VARIABLE(2);
+            gradB = OUTPUT_NULLIFIED(2);
         }
     }
     else if(block.width() == 5) {
         weightsPoint = INPUT_VARIABLE(3);
         bias = INPUT_VARIABLE(4);
-        gradWP = OUTPUT_VARIABLE(2);
-        gradB = OUTPUT_VARIABLE(3);
+        gradWP = OUTPUT_NULLIFIED(2);
+        gradB = OUTPUT_NULLIFIED(3);
     }

View File

@@ -32,7 +32,7 @@ namespace ops {
 //////////////////////////////////////////////////////////////////////
 CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) {
     auto input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
-    auto output = OUTPUT_VARIABLE(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
+    auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)

     const int factorH = INT_ARG(0);
     const int factorW = INT_ARG(1);

@@ -97,7 +97,7 @@ CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) {
     // NDArray<T>* input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
     auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)

     const int isNCHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC

View File

@@ -31,7 +31,7 @@ namespace ops {
 //////////////////////////////////////////////////////////////////////
 CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) {
     auto input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
-    auto output = OUTPUT_VARIABLE(0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
+    auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)

     const int factorD = INT_ARG(0);
     const int factorH = INT_ARG(1);

@@ -97,7 +97,7 @@ DECLARE_SHAPE_FN(upsampling3d) {
 CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) {
     // NDArray<T>* input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
     auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)

     const int isNCDHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC

View File

@@ -31,7 +31,7 @@ namespace ops {
 CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {

     auto input = INPUT_VARIABLE(0);
-    auto output = OUTPUT_VARIABLE(0);
+    auto output = OUTPUT_NULLIFIED(0);

     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;

@@ -147,7 +147,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
     auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

     int kH = INT_ARG(0); // filter(kernel) height
     int kW = INT_ARG(1); // filter(kernel) width

View File

@@ -32,7 +32,7 @@ namespace ops {
 CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
     auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
+    auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)

     int kD = INT_ARG(0); // filter(kernel) depth
     int kH = INT_ARG(1); // filter(kernel) height

@@ -149,7 +149,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
     auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
     auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

     const int kD = INT_ARG(0); // filter(kernel) depth
     const int kH = INT_ARG(1); // filter(kernel) height

View File

@@ -38,7 +38,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
     REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf());

     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-    auto output = OUTPUT_VARIABLE(0);
+    auto output = OUTPUT_NULLIFIED(0);

     const int kH = INT_ARG(0);
     const int kW = INT_ARG(1);

@@ -150,7 +150,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
     auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

     int kH = INT_ARG(0); // filter(kernel) height
     int kW = INT_ARG(1); // filter(kernel) width

View File

@@ -32,7 +32,7 @@ namespace ops {
 CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
     auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
+    auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)

     int kD = INT_ARG(0); // filter(kernel) depth
     int kH = INT_ARG(1); // filter(kernel) height

@@ -151,7 +151,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
     auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
     auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

     const int kD = INT_ARG(0); // filter(kernel) depth
     const int kH = INT_ARG(1); // filter(kernel) height

View File

@@ -30,14 +30,14 @@ namespace sd {
     CUSTOM_OP_IMPL(max_pool_with_argmax, 1, 2, false, 0, 9) {
         auto x = INPUT_VARIABLE(0);
-        auto z = OUTPUT_VARIABLE(0);
-        auto indeces = OUTPUT_VARIABLE(1);
+        auto z = OUTPUT_NULLIFIED(0);
+        auto indices = OUTPUT_NULLIFIED(1);

         REQUIRE_TRUE(x->rankOf() == 4, 0, "max_pool_with_argmax: Input should have rank of 4, but got %i instead", x->rankOf());

         auto argI = *(block.getIArguments());
-        helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indeces);
+        helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indices);

         return Status::OK();
     }

View File

@@ -32,7 +32,7 @@ namespace sd {
         REQUIRE_OK(this->validateInputLengthMatch(block));
         REQUIRE_OK(this->validateInputDimensionsMatch(block));
         auto input = INPUT_VARIABLE(0);
-        auto output = OUTPUT_VARIABLE(0);
+        auto output = OUTPUT_NULLIFIED(0);

         REQUIRE_TRUE(input->rankOf() == 4, 0, "PNORMPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf());

@@ -145,7 +145,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
     auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
-    auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

     int kH = INT_ARG(0); // filter(kernel) height
     int kW = INT_ARG(1); // filter(kernel) width

View File

@@ -33,7 +33,7 @@ CONFIGURABLE_OP_IMPL(dropout, 1, 1, true, 1, 1) {
     auto input = INPUT_VARIABLE(0); // lookup param
     NDArray *reduceShape = nullptr; // this param is optional
-    auto output = OUTPUT_VARIABLE(0); //
+    auto output = OUTPUT_NULLIFIED(0); //

     int seed = INT_ARG(0);

@@ -66,7 +66,7 @@ CONFIGURABLE_OP_IMPL(dropout_bp, 2, 1, false, 1, 1) {
     NDArray* gradOut = INPUT_VARIABLE(1); // lookup param
     NDArray* reduceShape = nullptr; // this param is optional
-    NDArray* output = OUTPUT_VARIABLE(0); //
+    NDArray* output = OUTPUT_NULLIFIED(0); //

     int seed = INT_ARG(0);

View File

@@ -30,7 +30,7 @@ namespace sd {
     CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) {
         auto a = INPUT_VARIABLE(0);
         auto b = INPUT_VARIABLE(1);
-        auto z = OUTPUT_VARIABLE(0);
+        auto z = OUTPUT_NULLIFIED(0);

         bool fastFlag = true;
         double l2_factor = 0.;
         if (block.numB() > 0) {

@@ -56,7 +56,7 @@ namespace sd {
     CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) {
         auto a = INPUT_VARIABLE(0);
         auto b = INPUT_VARIABLE(1);
-        auto z = OUTPUT_VARIABLE(0);
+        auto z = OUTPUT_NULLIFIED(0);

         bool fastFlag = true;
         double l2_factor = 0.;
         if (block.numB() > 0) {

View File

@@ -114,7 +114,7 @@ namespace sd {
     CUSTOM_OP_IMPL(logdet, 1, 1, false, 0, 0) {
         auto input = INPUT_VARIABLE(0);
-        auto output = OUTPUT_VARIABLE(0);
+        auto output = OUTPUT_NULLIFIED(0);

         REQUIRE_TRUE(input->rankOf() >=2, 0, "logdet: The rank of input array should not less than 2, but %i is given", input->rankOf());
         REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "logdet: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));

View File

@@ -76,8 +76,8 @@ namespace sd {
         auto input = INPUT_VARIABLE(0);
         auto indices = INPUT_VARIABLE(1);
         auto gradOut = INPUT_VARIABLE(2);
-        auto output = OUTPUT_VARIABLE(0);
-        auto outIndices = OUTPUT_VARIABLE(1);
+        auto output = OUTPUT_NULLIFIED(0);
+        auto outIndices = OUTPUT_NULLIFIED(1);
         outIndices->assign(indices);
         return helpers::segmentMaxFunctorBP(block.launchContext(), input, indices, gradOut, output);
     }

View File

@@ -76,8 +76,8 @@ namespace sd {
         auto input = INPUT_VARIABLE(0);
         auto indices = INPUT_VARIABLE(1);
         auto gradOut = INPUT_VARIABLE(2);
-        auto output = OUTPUT_VARIABLE(0);
-        auto outIndices = OUTPUT_VARIABLE(1);
+        auto output = OUTPUT_NULLIFIED(0);
+        auto outIndices = OUTPUT_NULLIFIED(1);
        outIndices->assign(indices);
        return helpers::segmentMeanFunctorBP(block.launchContext(), input, indices, gradOut, output);
    }

View File

@@ -66,8 +66,8 @@ namespace sd {
         auto input = INPUT_VARIABLE(0);
         auto indices = INPUT_VARIABLE(1);
         auto gradOut = INPUT_VARIABLE(2);
-        auto output = OUTPUT_VARIABLE(0);
-        auto outIndices = OUTPUT_VARIABLE(1);
+        auto output = OUTPUT_NULLIFIED(0);
+        auto outIndices = OUTPUT_NULLIFIED(1);
         outIndices->assign(indices);
         return helpers::segmentMinFunctorBP(block.launchContext(), input, indices, gradOut, output);
     }

View File

@@ -67,8 +67,8 @@ namespace sd {
         auto input = INPUT_VARIABLE(0);
         auto indices = INPUT_VARIABLE(1);
         auto gradOut = INPUT_VARIABLE(2);
-        auto output = OUTPUT_VARIABLE(0);
-        auto outIndices = OUTPUT_VARIABLE(1);
+        auto output = OUTPUT_NULLIFIED(0);
+        auto outIndices = OUTPUT_NULLIFIED(1);
         outIndices->assign(indices);
         helpers::segmentProdFunctorBP(block.launchContext(), input, indices, gradOut, output);

View File

@@ -65,7 +65,7 @@ namespace sd {
     CUSTOM_OP_IMPL(segment_sum_bp, 3, 2, false, 0, 0) {
-        return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), OUTPUT_VARIABLE(0));
+        return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), OUTPUT_NULLIFIED(0));
     }

     DECLARE_SHAPE_FN(segment_sum_bp){
         Nd4jLong* in = inputShape->at(0);

View File

@@ -25,7 +25,7 @@ namespace sd {
     namespace ops {
         CUSTOM_OP_IMPL(sequence_mask, 1, 1, false, 0, 0) {
             auto input = INPUT_VARIABLE(0);
-            auto output = OUTPUT_VARIABLE(0);
+            auto output = OUTPUT_NULLIFIED(0);

             const int inRank = input->rankOf();
             //REQUIRE_TRUE(inRank >= 1, 0, "sequence_mask: input array must have rank >= 1, but %i given!", inRank);

View File

@@ -26,7 +26,7 @@ namespace sd {
     CUSTOM_OP_IMPL(unsorted_segment_max, 2, 1, false, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto idxSegments = INPUT_VARIABLE(1);
-        auto segmentedOutput = OUTPUT_VARIABLE(0);
+        auto segmentedOutput = OUTPUT_NULLIFIED(0);
         Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
         REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
         REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

@@ -67,7 +67,7 @@ namespace sd {
     }

     CUSTOM_OP_IMPL(unsorted_segment_max_bp, 3, 2, false, 0, 1) {
-        return helpers::unsortedSegmentMaxFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
+        return helpers::unsortedSegmentMaxFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
     }

     DECLARE_TYPES(unsorted_segment_max_bp) {

View File

@@ -26,7 +26,7 @@ namespace sd {
     CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto idxSegments = INPUT_VARIABLE(1);
-        auto segmentedOutput = OUTPUT_VARIABLE(0);
+        auto segmentedOutput = OUTPUT_NULLIFIED(0);
         Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
         REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());

@@ -69,7 +69,7 @@ namespace sd {
     }

     CUSTOM_OP_IMPL(unsorted_segment_mean_bp, 3, 2, false, 0, 1) {
-        return helpers::unsortedSegmentMeanFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
+        return helpers::unsortedSegmentMeanFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
     }

     DECLARE_TYPES(unsorted_segment_mean_bp) {

View File

@@ -26,7 +26,7 @@ namespace sd {
     CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto idxSegments = INPUT_VARIABLE(1);
-        auto segmentedOutput = OUTPUT_VARIABLE(0);
+        auto segmentedOutput = OUTPUT_NULLIFIED(0);
         Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
         REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
         REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

@@ -69,7 +69,7 @@ namespace sd {
     }

     CUSTOM_OP_IMPL(unsorted_segment_min_bp, 3, 2, false, 0, 1) {
-        return helpers::unsortedSegmentMinFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
+        return helpers::unsortedSegmentMinFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
     }

     DECLARE_TYPES(unsorted_segment_min_bp) {

View File

@@ -26,7 +26,7 @@ namespace sd {
     CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto idxSegments = INPUT_VARIABLE(1);
-        auto segmentedOutput = OUTPUT_VARIABLE(0);
+        auto segmentedOutput = OUTPUT_NULLIFIED(0);
         Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
         REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
         REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

@@ -72,7 +72,7 @@ namespace sd {
         auto indices = INPUT_VARIABLE(1);
         auto eps = INPUT_VARIABLE(2);
         // auto numOfClasses = INT_ARG(0);
-        auto output = OUTPUT_VARIABLE(0);
+        auto output = OUTPUT_NULLIFIED(0);

         Nd4jLong numOfClasses = block.width() == 4 ? INPUT_VARIABLE(3)->e<Nd4jLong>(0) : INT_ARG(0);
         REQUIRE_TRUE(indices->isVector(), 0, "unsorted_segment_prod_bp: segment indexes array should be a vector, but it rank is %i.", indices->rankOf());

View File

@@ -26,7 +26,7 @@ namespace sd {
     CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 0) {
         auto input = INPUT_VARIABLE(0);
         auto idxSegments = INPUT_VARIABLE(1);
-        auto segmentedOutput = OUTPUT_VARIABLE(0);
+        auto segmentedOutput = OUTPUT_NULLIFIED(0);
         Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
         REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
         REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));

@@ -68,7 +68,7 @@ namespace sd {
     }

     CUSTOM_OP_IMPL(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1) {
-        return helpers::unsortedSegmentSqrtNFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
+        return helpers::unsortedSegmentSqrtNFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
     }

     DECLARE_TYPES(unsorted_segment_sqrt_n_bp) {
         getOpDescriptor()

View File
@@ -26,7 +26,7 @@ namespace sd {
CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
auto idxSegments = INPUT_VARIABLE(1);
-auto segmentedOutput = OUTPUT_VARIABLE(0);
+auto segmentedOutput = OUTPUT_NULLIFIED(0);
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
@@ -67,7 +67,7 @@ namespace sd {
return SHAPELIST(CONSTANT(outputShape));
}
CUSTOM_OP_IMPL(unsorted_segment_sum_bp, 3, 2, false, 0, 1) {
-return helpers::unsortedSegmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
+return helpers::unsortedSegmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
}
DECLARE_SHAPE_FN(unsorted_segment_sum_bp){
View File
@@ -42,7 +42,7 @@ namespace sd {
CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0);
-auto output = OUTPUT_VARIABLE(0);
+auto output = OUTPUT_NULLIFIED(0);
auto inputSamples = INPUT_VARIABLE(1);
View File
@@ -48,7 +48,7 @@ namespace sd {
*/
CUSTOM_OP_IMPL(firas_sparse, 1, 1, false, 0, -1) {
auto x = INPUT_VARIABLE(0);
-auto z = OUTPUT_VARIABLE(0);
+auto z = OUTPUT_NULLIFIED(0);
int batchSize = x->sizeAt(0);
int numColumns = x->sizeAt(1);
View File
@@ -34,7 +34,7 @@ namespace ops {
auto dataP = INPUT_VARIABLE(3);
auto N = INT_ARG(0);
-auto output = OUTPUT_VARIABLE(0);
+auto output = OUTPUT_NULLIFIED(0);
REQUIRE_TRUE(rowP->isVector(), 0, "barnes_edge_force: row input must be a vector, but its rank is %i instead !", rowP->rankOf());
REQUIRE_TRUE(colP->isVector(), 0, "barnes_edge_force: col input must be a vector, but its rank is %i instead !", colP->rankOf());
View File
@@ -54,8 +54,6 @@ void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output
const Nd4jLong imStride1 = imStride[1];
const Nd4jLong imStride2 = imStride[2];
const Nd4jLong imStride3 = imStride[3];
-memset(imBuff, 0, shape::length(imShapeBuffer) * sizeof(T));
// if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) {
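Presumably the memset removed above became redundant because zeroing now happens one level up: the op obtains its output through OUTPUT_NULLIFIED before calling the helper, so col2im only has to accumulate into an already-cleared buffer. A hedged fragment of that calling pattern (the surrounding variables are hypothetical; the helper call mirrors the one used in the col2im test further down):

// Hypothetical caller, for illustration only: `image` is assumed to arrive
// already zeroed (op code obtains it via OUTPUT_NULLIFIED), so the helper
// only accumulates column values into it instead of clearing it first.
LaunchContext ctx;
sd::ops::helpers::col2im(ctx, columns, image, sH, sW, pH, pW, iH, iW, dH, dW);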

View File
@@ -116,6 +116,8 @@ void bgemm(const std::vector<NDArray*>& vA, const std::vector<NDArray*>& vB, std
const auto bType = pB[0]->dataType();
const auto cType = pC[0]->dataType();
+std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
auto handle = reinterpret_cast<cublasHandle_t*>(context->getCublasHandle());
auto stream = context->getCudaStream();
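For readers skimming the hunk above: the added lock_guard takes a mutex obtained from LaunchContext::deviceMutex() and holds it for the rest of the scope, so callers that share that cuBLAS handle no longer race on it (the name suggests one mutex per device). A self-contained sketch of the same pattern with invented names (kMaxDevices, useSharedHandle and the printf are illustrations, not the actual LaunchContext API):

#include <cstdio>
#include <mutex>

constexpr int kMaxDevices = 16;              // assumption made for this sketch
static std::mutex deviceMutex[kMaxDevices];  // one mutex per device

// Serialize access to a per-device shared resource such as a cuBLAS handle:
// the lock is held for the whole call, so only one thread uses it at a time.
void useSharedHandle(int deviceId) {
    std::lock_guard<std::mutex> lock(deviceMutex[deviceId]);
    // ... configure and use the shared handle here (set stream, run GEMM, ...)
    std::printf("device %d: shared handle used under the device mutex\n", deviceId);
}

int main() {
    useSharedHandle(0);
    return 0;
}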

View File
@@ -96,6 +96,14 @@ namespace sd {
return _descriptor->getHash();
}
+sd::NDArray* sd::ops::DeclarableOp::getNullifiedZ(Context& block, int inputId) {
+auto result = getZ(block, inputId);
+if (result != nullptr && !block.isInplace())
+result->nullify();
+return result;
+}
sd::NDArray* sd::ops::DeclarableOp::getZ(Context& ctx, int inputId) {
NDArray* z = nullptr;
@@ -294,7 +302,8 @@ namespace sd {
if (Environment::getInstance()->isDebugAndVerbose())
shape::printShapeInfoLinear("Going to create variable with shape", out);
-auto outArr = new NDArray(out, true, ctx.launchContext());
+// we're creating non-initialized array here
+auto outArr = new NDArray(out, true, ctx.launchContext(), false);
ctx.pushNDArrayToVariableSpace(pair, outArr);
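Taken together, the two hunks above change the allocation policy: the output NDArray is created without the zero-fill pass (the new nullify argument is false), and getNullifiedZ() zeroes it on demand, skipping in-place execution. A self-contained C++ sketch of that allocate-now, zero-on-demand idea; the Output struct and its members are invented for illustration and are not libnd4j types:

#include <cstddef>
#include <cstring>

// Buffer that is only zeroed when the consumer asks for it, mirroring the
// `nullify` flag added to the NDArray constructor and NDArray::nullify().
struct Output {
    float* data;
    size_t len;
    Output(size_t n, bool nullify) : data(new float[n]), len(n) {
        // new float[n] leaves the contents indeterminate, so there is no
        // zero-fill cost unless the caller explicitly asks for it.
        if (nullify)
            std::memset(data, 0, len * sizeof(float));
    }
    void zero() { std::memset(data, 0, len * sizeof(float)); }
    ~Output() { delete[] data; }
};

int main() {
    Output overwritten(1024, /*nullify=*/false);   // op writes every element anyway
    Output accumulated(1024, /*nullify=*/true);    // accumulating op wants zeros
    overwritten.zero();                            // or zero later, on demand
    return 0;
}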

View File
@@ -31,7 +31,15 @@ namespace sd {
_engine = engine;
}
-sd::NDArray *PlatformHelper::getZ(graph::Context &ctx, int inputId) {
+sd::NDArray* PlatformHelper::getNullifiedZ(graph::Context& block, int inputId) {
+auto result = getZ(block, inputId);
+if (result != nullptr && !block.isInplace())
+result->nullify();
+return result;
+}
+sd::NDArray* PlatformHelper::getZ(graph::Context &ctx, int inputId) {
NDArray *z = nullptr;
if (ctx.isFastPath()) {
View File
@@ -540,9 +540,9 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
-auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always
+auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always
-auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
+auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
int kH = INT_ARG(0); // filter(kernel) height
int kW = INT_ARG(1); // filter(kernel) width
View File
@@ -542,9 +542,9 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
-auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
+auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
-auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
+auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC] always
-auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
+auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
View File
@@ -398,9 +398,9 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
-auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
+auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
-auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always
+auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always
-auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
+auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
View File
@@ -1513,6 +1513,7 @@
#define INPUT_VARIABLE(INDEX) block.array(INDEX)
#define OUTPUT_VARIABLE(INDEX) reinterpret_cast<sd::NDArray *>(this->getZ(block, INDEX))
+#define OUTPUT_NULLIFIED(INDEX) reinterpret_cast<sd::NDArray *>(this->getNullifiedZ(block, INDEX))
#define INPUT_LIST(INDEX) reinterpret_cast<sd::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())
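From an op author's point of view, the new macro is a drop-in alternative to OUTPUT_VARIABLE that routes through getNullifiedZ(), so the buffer comes back zeroed unless the op runs in-place. A hedged fragment (not a complete, compilable op; the surrounding CUSTOM_OP_IMPL body is implied) showing how the two macros are meant to be chosen:

// Inside a hypothetical CUSTOM_OP_IMPL(...) body:
auto input = INPUT_VARIABLE(0);

// OUTPUT_VARIABLE: buffer contents are left uninitialized at allocation time,
// so the op must write every element itself.
auto zRaw = OUTPUT_VARIABLE(0);

// OUTPUT_NULLIFIED: expands to getNullifiedZ(), which calls nullify() on the
// buffer (skipped for in-place execution), so accumulating kernels start at 0.
auto zZeroed = OUTPUT_NULLIFIED(0);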

View File
@@ -2128,8 +2128,10 @@ TEST_F(ConvolutionTests1, col2im_test1) {
auto imageExpected = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}, {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, 49.f, 103.f, 108.f, 226.f});
-LaunchContext ctx;
-sd::ops::helpers::col2im(ctx, columns, image, sH, sW, pH, pW, iH, iW, dH, dW);
+sd::ops::col2im op;
+auto status = op.execute({&columns}, {&image}, {sH, sW, pH, pW, iH, iW, dH, dW, 0});
+ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(image.equalsTo(imageExpected));
}
View File
@@ -18,7 +18,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public IntVectorVector(long n) { allocate(n); }
private native void allocate();
private native void allocate(@Cast("size_t") long n);
-public native @Name("operator=") @ByRef IntVectorVector put(@ByRef IntVectorVector x);
+public native @Name("operator =") @ByRef IntVectorVector put(@ByRef IntVectorVector x);
public boolean empty() { return size() == 0; }
public native long size();
@@ -67,7 +67,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public LongVectorVector(long n) { allocate(n); }
private native void allocate();
private native void allocate(@Cast("size_t") long n);
-public native @Name("operator=") @ByRef LongVectorVector put(@ByRef LongVectorVector x);
+public native @Name("operator =") @ByRef LongVectorVector put(@ByRef LongVectorVector x);
public boolean empty() { return size() == 0; }
public native long size();
@@ -117,7 +117,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public NDArrayVector(long n) { allocate(n); }
private native void allocate();
private native void allocate(@Cast("size_t") long n);
-public native @Name("operator=") @ByRef NDArrayVector put(@ByRef NDArrayVector x);
+public native @Name("operator =") @ByRef NDArrayVector put(@ByRef NDArrayVector x);
public boolean empty() { return size() == 0; }
public native long size();
@@ -135,9 +135,9 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public Iterator(Pointer p) { super(p); }
public Iterator() { }
-public native @Name("operator++") @ByRef Iterator increment();
+public native @Name("operator ++") @ByRef Iterator increment();
-public native @Name("operator==") boolean equals(@ByRef Iterator it);
+public native @Name("operator ==") boolean equals(@ByRef Iterator it);
-public native @Name("operator*") @Const NDArray get();
+public native @Name("operator *") @Const NDArray get();
}
public NDArray[] get() {
@@ -185,7 +185,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public ConstNDArrayVector(long n) { allocate(n); }
private native void allocate();
private native void allocate(@Cast("size_t") long n);
-public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x);
+public native @Name("operator =") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x);
public boolean empty() { return size() == 0; }
public native long size();
@@ -203,9 +203,9 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public Iterator(Pointer p) { super(p); }
public Iterator() { }
-public native @Name("operator++") @ByRef Iterator increment();
+public native @Name("operator ++") @ByRef Iterator increment();
-public native @Name("operator==") boolean equals(@ByRef Iterator it);
+public native @Name("operator ==") boolean equals(@ByRef Iterator it);
-public native @Name("operator*") @Const NDArray get();
+public native @Name("operator *") @Const NDArray get();
}
public NDArray[] get() {
@@ -250,7 +250,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public IntIntPair(int firstValue, int secondValue) { this(); put(firstValue, secondValue); }
public IntIntPair() { allocate(); }
private native void allocate();
-public native @Name("operator=") @ByRef IntIntPair put(@ByRef IntIntPair x);
+public native @Name("operator =") @ByRef IntIntPair put(@ByRef IntIntPair x);
@MemberGetter public native int first(); public native IntIntPair first(int first);
@@ -3733,16 +3733,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
/**
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
*/
-public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo);
-public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo);
-public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo);
@@ -3750,16 +3750,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
* set dtype as array type
*/
-public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype);
-public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype);
-public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype);
@@ -9492,6 +9492,14 @@ public static final int PREALLOC_SIZE = 33554432;
* @return
*/
public native NDArray getZ(@ByRef Context ctx, int inputId);
+/**
+* Helper method, needed for compatibility with DeclarableOp macros
+* @param ctx
+* @param inputId
+* @return
+*/
+public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId);
}
@@ -10289,7 +10297,6 @@ public static final int PREALLOC_SIZE = 33554432;
// #ifndef __JAVACPP_HACK__
// #endif // JCPP
// #endif // CUDA
@@ -10308,6 +10315,10 @@ public static final int PREALLOC_SIZE = 33554432;
public native void setDeviceID(int deviceID);
public native ErrorReference errorReference();
+// #ifndef __JAVACPP_HACK__
+// #endif
public static native @Cast("bool") boolean isInitialized();
public static native void releaseBuffers();
View File
@@ -21,7 +21,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public IntVectorVector(long n) { allocate(n); }
private native void allocate();
private native void allocate(@Cast("size_t") long n);
-public native @Name("operator=") @ByRef IntVectorVector put(@ByRef IntVectorVector x);
+public native @Name("operator =") @ByRef IntVectorVector put(@ByRef IntVectorVector x);
public boolean empty() { return size() == 0; }
public native long size();
@@ -70,7 +70,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public LongVectorVector(long n) { allocate(n); }
private native void allocate();
private native void allocate(@Cast("size_t") long n);
-public native @Name("operator=") @ByRef LongVectorVector put(@ByRef LongVectorVector x);
+public native @Name("operator =") @ByRef LongVectorVector put(@ByRef LongVectorVector x);
public boolean empty() { return size() == 0; }
public native long size();
@@ -120,7 +120,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public ConstNDArrayVector(long n) { allocate(n); }
private native void allocate();
private native void allocate(@Cast("size_t") long n);
-public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x);
+public native @Name("operator =") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x);
public boolean empty() { return size() == 0; }
public native long size();
@@ -138,9 +138,9 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public Iterator(Pointer p) { super(p); }
public Iterator() { }
-public native @Name("operator++") @ByRef Iterator increment();
+public native @Name("operator ++") @ByRef Iterator increment();
-public native @Name("operator==") boolean equals(@ByRef Iterator it);
+public native @Name("operator ==") boolean equals(@ByRef Iterator it);
-public native @Name("operator*") @Const NDArray get();
+public native @Name("operator *") @Const NDArray get();
}
public NDArray[] get() {
@@ -188,7 +188,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public NDArrayVector(long n) { allocate(n); }
private native void allocate();
private native void allocate(@Cast("size_t") long n);
-public native @Name("operator=") @ByRef NDArrayVector put(@ByRef NDArrayVector x);
+public native @Name("operator =") @ByRef NDArrayVector put(@ByRef NDArrayVector x);
public boolean empty() { return size() == 0; }
public native long size();
@@ -206,9 +206,9 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public Iterator(Pointer p) { super(p); }
public Iterator() { }
-public native @Name("operator++") @ByRef Iterator increment();
+public native @Name("operator ++") @ByRef Iterator increment();
-public native @Name("operator==") boolean equals(@ByRef Iterator it);
+public native @Name("operator ==") boolean equals(@ByRef Iterator it);
-public native @Name("operator*") @Const NDArray get();
+public native @Name("operator *") @Const NDArray get();
}
public NDArray[] get() {
@@ -253,7 +253,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public IntIntPair(int firstValue, int secondValue) { this(); put(firstValue, secondValue); }
public IntIntPair() { allocate(); }
private native void allocate();
-public native @Name("operator=") @ByRef IntIntPair put(@ByRef IntIntPair x);
+public native @Name("operator =") @ByRef IntIntPair put(@ByRef IntIntPair x);
@MemberGetter public native int first(); public native IntIntPair first(int first);
@@ -3736,16 +3736,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
/**
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
*/
-public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo);
-public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo);
-public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo);
@@ -3753,16 +3753,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
* set dtype as array type
*/
-public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype);
-public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype);
-public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
+public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
-private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
+private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype);
@@ -11459,6 +11459,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #define INPUT_VARIABLE(INDEX) block.array(INDEX)
// #define OUTPUT_VARIABLE(INDEX) reinterpret_cast<sd::NDArray *>(this->getZ(block, INDEX))
+// #define OUTPUT_NULLIFIED(INDEX) reinterpret_cast<sd::NDArray *>(this->getNullifiedZ(block, INDEX))
// #define INPUT_LIST(INDEX) reinterpret_cast<sd::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())
@@ -11809,6 +11810,14 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* @return
*/
public native NDArray getZ(@ByRef Context ctx, int inputId);
+/**
+* Helper method, needed for compatibility with DeclarableOp macros
+* @param ctx
+* @param inputId
+* @return
+*/
+public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId);
}
@@ -24205,6 +24214,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native void setDeviceID(int deviceID);
public native ErrorReference errorReference();
+// #ifndef __JAVACPP_HACK__
+// #endif
public static native @Cast("bool") boolean isInitialized();
public static native void releaseBuffers();
View File
@@ -28,6 +28,7 @@ import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.imports.TFGraphs.NodeReader;
+import org.nd4j.linalg.api.blas.BlasBufferUtil;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.blas.params.GemmParams;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
@@ -106,6 +107,7 @@ import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
+import java.util.concurrent.CountDownLatch;
import static org.junit.Assert.*;
import static org.junit.Assert.assertArrayEquals;