Nullify (#304)
* initial commit Signed-off-by: raver119 <raver119@gmail.com> * bunch of tweaks Signed-off-by: raver119 <raver119@gmail.com> * hamming distance nullification Signed-off-by: raver119 <raver119@gmail.com> * Add output array value assignment for testing/debugging Signed-off-by: Alex Black <blacka101@gmail.com> * don't assign empty arrays Signed-off-by: raver119 <raver119@gmail.com> * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 <raver119@gmail.com> * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 <raver119@gmail.com> * conv2d/conv3d/depthwise2d nullified Signed-off-by: raver119 <raver119@gmail.com> * few more fixes Signed-off-by: raver119 <raver119@gmail.com> * im2col Signed-off-by: raver119 <raver119@gmail.com> * pooling? Signed-off-by: raver119 <raver119@gmail.com> * more nullified Signed-off-by: raver119 <raver119@gmail.com> * ismax nullified Signed-off-by: raver119 <raver119@gmail.com> * rollback ismax nullification Signed-off-by: raver119 <raver119@gmail.com> * synchronized cublas handle use on per-device basis Signed-off-by: raver119 <raver119@gmail.com> * hiding method from jcpp Signed-off-by: raver119 <raver119@gmail.com> * get rid of test assigns in DeclarableOp Signed-off-by: raver119 <raver119@gmail.com> * get rid of assigns Signed-off-by: raver119 <raver119@gmail.com> * proper deviceId is back Signed-off-by: raver119 <raver119@gmail.com> * include fixed Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
30a28fae45
commit
7a2ac800dd
|
@ -277,13 +277,13 @@ namespace sd {
|
||||||
/**
|
/**
|
||||||
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
||||||
*/
|
*/
|
||||||
NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
|
NDArray(Nd4jLong* shapeInfo, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
||||||
* set dtype as array type
|
* set dtype as array type
|
||||||
*/
|
*/
|
||||||
NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext());
|
NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides = false, sd::LaunchContext* context = sd::LaunchContext::defaultContext(), const bool nullify = true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this constructor creates new array using shape information contained in vector argument
|
* this constructor creates new array using shape information contained in vector argument
|
||||||
|
|
|
@ -143,7 +143,7 @@ NDArray::NDArray(void* buffer, const char order, const std::vector<Nd4jLong> &sh
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
// creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros
|
// creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros
|
||||||
NDArray::NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context) {
|
NDArray::NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context, const bool nullify) {
|
||||||
|
|
||||||
if (shapeInfo == nullptr)
|
if (shapeInfo == nullptr)
|
||||||
throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo");
|
throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo");
|
||||||
|
@ -161,7 +161,9 @@ NDArray::NDArray(Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyS
|
||||||
|
|
||||||
if (!isEmpty()) {
|
if (!isEmpty()) {
|
||||||
_buffer = std::make_shared<DataBuffer>(lengthOf() * sizeOfT(), dtype, getContext()->getWorkspace());
|
_buffer = std::make_shared<DataBuffer>(lengthOf() * sizeOfT(), dtype, getContext()->getWorkspace());
|
||||||
_buffer->setToZeroBuffers();
|
|
||||||
|
if (nullify)
|
||||||
|
_buffer->setToZeroBuffers();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -213,7 +215,7 @@ NDArray::NDArray(sd::LaunchContext * context) {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
// creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros, set dtype as array type
|
// creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros, set dtype as array type
|
||||||
NDArray::NDArray(Nd4jLong* shapeInfo, const bool copyStrides, sd::LaunchContext * context):
|
NDArray::NDArray(Nd4jLong* shapeInfo, const bool copyStrides, sd::LaunchContext * context, const bool nullify):
|
||||||
NDArray(shapeInfo, ArrayOptions::dataType(shapeInfo), copyStrides, context) {
|
NDArray(shapeInfo, ArrayOptions::dataType(shapeInfo), copyStrides, context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3339,9 +3341,6 @@ void NDArray::nullify() {
|
||||||
if (isEmpty())
|
if (isEmpty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if (isS())
|
|
||||||
throw std::runtime_error("NDArray::nullify: can't nullify string array");
|
|
||||||
|
|
||||||
if (isView() || ews() != 1)
|
if (isView() || ews() != 1)
|
||||||
assign(0);
|
assign(0);
|
||||||
else
|
else
|
||||||
|
|
|
@ -54,6 +54,8 @@ class ND4J_EXPORT LaunchContext {
|
||||||
static std::vector<std::shared_ptr<LaunchContext>> _contexts;
|
static std::vector<std::shared_ptr<LaunchContext>> _contexts;
|
||||||
static std::mutex _mutex;
|
static std::mutex _mutex;
|
||||||
|
|
||||||
|
static MAP_IMPL<int, std::mutex*> _deviceMutexes;
|
||||||
|
|
||||||
// used for MKLDNN
|
// used for MKLDNN
|
||||||
void *_engine = nullptr;
|
void *_engine = nullptr;
|
||||||
|
|
||||||
|
@ -93,7 +95,6 @@ class ND4J_EXPORT LaunchContext {
|
||||||
void setCudaSpecialStream(cudaStream_t* cudaStream);
|
void setCudaSpecialStream(cudaStream_t* cudaStream);
|
||||||
void setCublasHandle(void *handle);
|
void setCublasHandle(void *handle);
|
||||||
|
|
||||||
|
|
||||||
#endif // JCPP
|
#endif // JCPP
|
||||||
|
|
||||||
#endif // CUDA
|
#endif // CUDA
|
||||||
|
@ -111,6 +112,12 @@ class ND4J_EXPORT LaunchContext {
|
||||||
void setDeviceID(int deviceID) { _deviceID = deviceID; }
|
void setDeviceID(int deviceID) { _deviceID = deviceID; }
|
||||||
sd::ErrorReference* errorReference();
|
sd::ErrorReference* errorReference();
|
||||||
|
|
||||||
|
#ifndef __JAVACPP_HACK__
|
||||||
|
// this method returns mutex shared between all threads that use the same device
|
||||||
|
static std::mutex* deviceMutex();
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
static bool isInitialized();
|
static bool isInitialized();
|
||||||
static void releaseBuffers();
|
static void releaseBuffers();
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <execution/LaunchContext.h>
|
#include <execution/LaunchContext.h>
|
||||||
|
#include <execution/AffinityManager.h>
|
||||||
#include <helpers/logger.h>
|
#include <helpers/logger.h>
|
||||||
#include <exceptions/cuda_exception.h>
|
#include <exceptions/cuda_exception.h>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
|
@ -42,6 +43,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
|
std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
|
||||||
|
MAP_IMPL<int, std::mutex*> LaunchContext::_deviceMutexes;
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
LaunchContext::LaunchContext() {
|
LaunchContext::LaunchContext() {
|
||||||
|
@ -49,6 +51,8 @@ namespace sd {
|
||||||
_workspace = nullptr;
|
_workspace = nullptr;
|
||||||
_deviceID = 0;
|
_deviceID = 0;
|
||||||
|
|
||||||
|
_deviceMutexes[_deviceID] = new std::mutex();
|
||||||
|
|
||||||
#ifdef HAVE_MKLDNN
|
#ifdef HAVE_MKLDNN
|
||||||
_engine = new dnnl::engine(dnnl::engine::kind::cpu, 0);
|
_engine = new dnnl::engine(dnnl::engine::kind::cpu, 0);
|
||||||
#endif
|
#endif
|
||||||
|
@ -68,6 +72,11 @@ namespace sd {
|
||||||
return LaunchContext::_contexts[0].get();
|
return LaunchContext::_contexts[0].get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::mutex* LaunchContext::deviceMutex() {
|
||||||
|
auto deviceId = AffinityManager::currentDeviceId();
|
||||||
|
return _deviceMutexes[deviceId];
|
||||||
|
}
|
||||||
|
|
||||||
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
|
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,7 @@ namespace sd {
|
||||||
|
|
||||||
std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
|
std::vector<std::shared_ptr<LaunchContext>> LaunchContext::_contexts = std::vector<std::shared_ptr<LaunchContext>>();
|
||||||
std::mutex LaunchContext::_mutex;
|
std::mutex LaunchContext::_mutex;
|
||||||
|
MAP_IMPL<int, std::mutex*> LaunchContext::_deviceMutexes;
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {
|
LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {
|
||||||
|
@ -44,6 +45,11 @@ LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCuda
|
||||||
_isAllocated = false;
|
_isAllocated = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::mutex* LaunchContext::deviceMutex() {
|
||||||
|
auto deviceId = AffinityManager::currentDeviceId();
|
||||||
|
return _deviceMutexes[deviceId];
|
||||||
|
}
|
||||||
|
|
||||||
LaunchContext::~LaunchContext() {
|
LaunchContext::~LaunchContext() {
|
||||||
if (_isAllocated) {
|
if (_isAllocated) {
|
||||||
|
|
||||||
|
@ -85,6 +91,8 @@ LaunchContext::LaunchContext() {
|
||||||
|
|
||||||
_contexts.resize(numDevices);
|
_contexts.resize(numDevices);
|
||||||
for (int e = 0; e < numDevices; e++) {
|
for (int e = 0; e < numDevices; e++) {
|
||||||
|
_deviceMutexes[e] = new std::mutex();
|
||||||
|
|
||||||
AffinityManager::setCurrentNativeDevice(e);
|
AffinityManager::setCurrentNativeDevice(e);
|
||||||
|
|
||||||
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
|
LaunchContext::_contexts[e] = std::make_shared<LaunchContext>();
|
||||||
|
|
|
@ -252,6 +252,8 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
const bool typeIntFloat = AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6;
|
const bool typeIntFloat = AB && aType == DataType::INT8 && cType == DataType::FLOAT32 && major >= 6;
|
||||||
const bool typeHalfFloat = AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6;
|
const bool typeHalfFloat = AB && aType == DataType::HALF && cType == DataType::FLOAT32 && major >= 6;
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
|
||||||
|
|
||||||
auto handle = reinterpret_cast<cublasHandle_t *>(A->getContext()->getCublasHandle());
|
auto handle = reinterpret_cast<cublasHandle_t *>(A->getContext()->getCublasHandle());
|
||||||
auto stream = A->getContext()->getCudaStream();
|
auto stream = A->getContext()->getCudaStream();
|
||||||
|
|
||||||
|
@ -394,6 +396,8 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, sd::NDArray* Y,
|
||||||
const bool typeDouble = AXY && aType == DataType::DOUBLE;
|
const bool typeDouble = AXY && aType == DataType::DOUBLE;
|
||||||
const bool typeFloat = AXY && aType == DataType::FLOAT32;
|
const bool typeFloat = AXY && aType == DataType::FLOAT32;
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
|
||||||
|
|
||||||
auto handle = reinterpret_cast<cublasHandle_t *>(A->getContext()->getCublasHandle());
|
auto handle = reinterpret_cast<cublasHandle_t *>(A->getContext()->getCublasHandle());
|
||||||
auto stream = A->getContext()->getCudaStream();
|
auto stream = A->getContext()->getCudaStream();
|
||||||
|
|
||||||
|
|
|
@ -106,6 +106,7 @@ namespace sd {
|
||||||
void storeResult(Context &block, int outputNumber, NDArray& array);
|
void storeResult(Context &block, int outputNumber, NDArray& array);
|
||||||
void storeResult(Context &block, int outputNumber, NDArray* array);
|
void storeResult(Context &block, int outputNumber, NDArray* array);
|
||||||
sd::NDArray* getZ(Context& block, int inputId = 0);
|
sd::NDArray* getZ(Context& block, int inputId = 0);
|
||||||
|
sd::NDArray* getNullifiedZ(Context& block, int inputId = 0);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method pre-allocates NDArrays for Op output, in case they are not available at op execution time
|
* This method pre-allocates NDArrays for Op output, in case they are not available at op execution time
|
||||||
|
|
|
@ -77,7 +77,15 @@ namespace sd {
|
||||||
* @param inputId
|
* @param inputId
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
sd::NDArray *getZ(graph::Context &ctx, int inputId);
|
sd::NDArray* getZ(graph::Context &ctx, int inputId);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper method, needed for compatibility with DeclarableOp macros
|
||||||
|
* @param ctx
|
||||||
|
* @param inputId
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
sd::NDArray* getNullifiedZ(graph::Context &ctx, int inputId);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) {
|
CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length");
|
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length");
|
||||||
REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type");
|
REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type");
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace sd {
|
||||||
auto values = INPUT_VARIABLE(2);
|
auto values = INPUT_VARIABLE(2);
|
||||||
NDArray *def = nullptr;
|
NDArray *def = nullptr;
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
if (block.width() > 3)
|
if (block.width() > 3)
|
||||||
def = INPUT_VARIABLE(3);
|
def = INPUT_VARIABLE(3);
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace sd {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto delim = INPUT_VARIABLE(1);
|
auto delim = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
auto indices = OUTPUT_VARIABLE(0);
|
auto indices = OUTPUT_NULLIFIED(0);
|
||||||
auto values = OUTPUT_VARIABLE(1);
|
auto values = OUTPUT_VARIABLE(1);
|
||||||
|
|
||||||
auto d = delim->e<std::string>(0);
|
auto d = delim->e<std::string>(0);
|
||||||
|
|
|
@ -28,7 +28,7 @@ namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(col2im, 1, 1, false, 0, 9) {
|
CUSTOM_OP_IMPL(col2im, 1, 1, false, 0, 9) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(x->rankOf() == 6, 0, "col2im input should be 6D, but got %i instead", x->rankOf());
|
REQUIRE_TRUE(x->rankOf() == 6, 0, "col2im input should be 6D, but got %i instead", x->rankOf());
|
||||||
REQUIRE_TRUE(z->rankOf() == 4, 0, "col2im output should be 4D, but got %i instead", z->rankOf());
|
REQUIRE_TRUE(z->rankOf() == 4, 0, "col2im output should be 4D, but got %i instead", z->rankOf());
|
||||||
|
@ -45,8 +45,6 @@ namespace sd {
|
||||||
LaunchContext* ctx = block.launchContext();
|
LaunchContext* ctx = block.launchContext();
|
||||||
helpers::col2im(*ctx, *x, *z, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX);
|
helpers::col2im(*ctx, *x, *z, strideY, strideX, padHeight, padWidth, imgHeight, imgWidth, dY, dX);
|
||||||
|
|
||||||
STORE_RESULT(*z);
|
|
||||||
|
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
DECLARE_SHAPE_FN(col2im) {
|
DECLARE_SHAPE_FN(col2im) {
|
||||||
|
|
|
@ -37,7 +37,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
|
||||||
auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always
|
auto weights = INPUT_VARIABLE(1); // [kW, iC, oC] always
|
||||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW)
|
auto output = OUTPUT_NULLIFIED(0); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW)
|
||||||
|
|
||||||
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
|
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
|
||||||
int sW = INT_ARG(1); // strides width
|
int sW = INT_ARG(1); // strides width
|
||||||
|
@ -167,9 +167,9 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oW, oC] (NWC) or [bS, oC, oW] (NCW), epsilon_next
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iW, iC] (NWC) or [bS, iC, iW] (NCW), epsilon
|
||||||
auto gradW = OUTPUT_VARIABLE(1); // [kW, iC, oC] always
|
auto gradW = OUTPUT_NULLIFIED(1); // [kW, iC, oC] always
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
|
||||||
|
|
||||||
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
|
int kW = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) width
|
||||||
int sW = INT_ARG(1); // strides width
|
int sW = INT_ARG(1); // strides width
|
||||||
|
|
|
@ -40,7 +40,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
|
||||||
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
||||||
|
|
||||||
int sH = INT_ARG(2); // strides height
|
int sH = INT_ARG(2); // strides height
|
||||||
int sW = INT_ARG(3); // strides width
|
int sW = INT_ARG(3); // strides width
|
||||||
|
@ -161,9 +161,9 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
|
||||||
|
|
||||||
int kH = INT_ARG(0); // filter(kernel) height
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
int kW = INT_ARG(1); // filter(kernel) width
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
@ -267,7 +267,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
|
||||||
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||||
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
auto gradO = INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
|
||||||
int kH = INT_ARG(0); // filter(kernel) height
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
int kW = INT_ARG(1); // filter(kernel) width
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
|
|
@ -38,7 +38,7 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
|
||||||
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, oC, iC] always
|
||||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
||||||
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
||||||
|
|
|
@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
|
||||||
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||||
auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)
|
auto gradIShape = INPUT_VARIABLE(0); // [4] - shape of input of conv2d (that is shape of gradI)
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
|
||||||
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));// filter(kernel) height
|
||||||
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
|
int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));// filter(kernel) width
|
||||||
|
|
|
@ -35,7 +35,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
|
||||||
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
||||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] = iC*mC
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
|
auto output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
||||||
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
||||||
|
@ -152,9 +152,9 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
|
||||||
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
||||||
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
||||||
|
|
|
@ -30,8 +30,7 @@ namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(im2col, 1, 1, false, 0, 9) {
|
CUSTOM_OP_IMPL(im2col, 1, 1, false, 0, 9) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
|
|
||||||
REQUIRE_TRUE(x->rankOf() == 4, 0, "im2col input should be 4D, but got %i instead", x->rankOf());
|
REQUIRE_TRUE(x->rankOf() == 4, 0, "im2col input should be 4D, but got %i instead", x->rankOf());
|
||||||
REQUIRE_TRUE(z->rankOf() == 6, 0, "im2col output should be 6D, but got %i instead", z->rankOf());
|
REQUIRE_TRUE(z->rankOf() == 6, 0, "im2col output should be 6D, but got %i instead", z->rankOf());
|
||||||
|
@ -53,8 +52,6 @@ namespace sd {
|
||||||
LaunchContext* ctx = block.launchContext();
|
LaunchContext* ctx = block.launchContext();
|
||||||
sd::ops::helpers::im2col(*ctx, *x, *z, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, NDArrayFactory::create(zeroPadVal, block.launchContext()));
|
sd::ops::helpers::im2col(*ctx, *x, *z, kernelHeight, kernelWidth, strideY, strideX, padHeight, padWidth, dY, dX, NDArrayFactory::create(zeroPadVal, block.launchContext()));
|
||||||
|
|
||||||
STORE_RESULT(*z);
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,7 +104,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(im2col_bp, 2, 1, false, 0, 9) {
|
CUSTOM_OP_IMPL(im2col_bp, 2, 1, false, 0, 9) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto gradAtOutput = INPUT_VARIABLE(1);
|
auto gradAtOutput = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "im2col_bp input should be 4D, but got %i instead", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "im2col_bp input should be 4D, but got %i instead", input->rankOf());
|
||||||
REQUIRE_TRUE(gradAtOutput->rankOf() == 6, 0, "im2col_bp gradient at output (input idx 1) should be 6D, but got %i instead", gradAtOutput->rankOf());
|
REQUIRE_TRUE(gradAtOutput->rankOf() == 6, 0, "im2col_bp gradient at output (input idx 1) should be 6D, but got %i instead", gradAtOutput->rankOf());
|
||||||
|
|
|
@ -37,7 +37,7 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
|
||||||
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
|
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
|
||||||
NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC
|
NDArray *bias = nullptr; // [oC], if weightsPoint=nullptr then oC = iC*mC
|
||||||
|
|
||||||
NDArray *output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
NDArray *output = OUTPUT_NULLIFIED(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
||||||
|
|
||||||
if(block.width() == 3) {
|
if(block.width() == 3) {
|
||||||
if((INPUT_VARIABLE(2))->rankOf() == 4)
|
if((INPUT_VARIABLE(2))->rankOf() == 4)
|
||||||
|
@ -199,26 +199,26 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
|
||||||
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
|
NDArray *weightsPoint = nullptr; // [1, 1, iC*mC, oC] always
|
||||||
NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr
|
NDArray *bias = nullptr; // [oC], oC = iC*mC if weightsPoint=nullptr
|
||||||
|
|
||||||
NDArray *gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
NDArray *gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
NDArray *gradWD = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
NDArray *gradWD = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always
|
||||||
NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC] always
|
NDArray *gradWP = nullptr; // [1, 1, iC*mC, oC] always
|
||||||
NDArray *gradB = nullptr; // [oC]
|
NDArray *gradB = nullptr; // [oC]
|
||||||
|
|
||||||
if(block.width() == 4) {
|
if(block.width() == 4) {
|
||||||
if((INPUT_VARIABLE(3))->rankOf() == 4) {
|
if((INPUT_VARIABLE(3))->rankOf() == 4) {
|
||||||
weightsPoint = INPUT_VARIABLE(3);
|
weightsPoint = INPUT_VARIABLE(3);
|
||||||
gradWP = OUTPUT_VARIABLE(2);
|
gradWP = OUTPUT_NULLIFIED(2);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
bias = INPUT_VARIABLE(3);
|
bias = INPUT_VARIABLE(3);
|
||||||
gradB = OUTPUT_VARIABLE(2);
|
gradB = OUTPUT_NULLIFIED(2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(block.width() == 5) {
|
else if(block.width() == 5) {
|
||||||
weightsPoint = INPUT_VARIABLE(3);
|
weightsPoint = INPUT_VARIABLE(3);
|
||||||
bias = INPUT_VARIABLE(4);
|
bias = INPUT_VARIABLE(4);
|
||||||
gradWP = OUTPUT_VARIABLE(2);
|
gradWP = OUTPUT_NULLIFIED(2);
|
||||||
gradB = OUTPUT_VARIABLE(3);
|
gradB = OUTPUT_NULLIFIED(3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace ops {
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) {
|
CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) {
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
|
auto input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
|
||||||
auto output = OUTPUT_VARIABLE(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
|
auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
|
||||||
|
|
||||||
const int factorH = INT_ARG(0);
|
const int factorH = INT_ARG(0);
|
||||||
const int factorW = INT_ARG(1);
|
const int factorW = INT_ARG(1);
|
||||||
|
@ -97,7 +97,7 @@ CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
// NDArray<T>* input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
|
// NDArray<T>* input = INPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
|
||||||
auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
|
auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC)
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC)
|
||||||
|
|
||||||
const int isNCHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC
|
const int isNCHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ namespace ops {
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) {
|
CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) {
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
|
auto input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
|
||||||
auto output = OUTPUT_VARIABLE(0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
|
auto output = OUTPUT_NULLIFIED(0); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
|
||||||
|
|
||||||
const int factorD = INT_ARG(0);
|
const int factorD = INT_ARG(0);
|
||||||
const int factorH = INT_ARG(1);
|
const int factorH = INT_ARG(1);
|
||||||
|
@ -97,7 +97,7 @@ DECLARE_SHAPE_FN(upsampling3d) {
|
||||||
CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) {
|
||||||
// NDArray<T>* input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
|
// NDArray<T>* input = INPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
|
||||||
auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
|
auto gradO = INPUT_VARIABLE(1); // [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC)
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC)
|
||||||
|
|
||||||
const int isNCDHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC
|
const int isNCDHW = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // INT_ARG(0): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ namespace ops {
|
||||||
CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
|
CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||||
|
|
||||||
|
@ -147,7 +147,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
|
||||||
int kH = INT_ARG(0); // filter(kernel) height
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
int kW = INT_ARG(1); // filter(kernel) width
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace ops {
|
||||||
CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
|
CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
||||||
|
|
||||||
int kD = INT_ARG(0); // filter(kernel) depth
|
int kD = INT_ARG(0); // filter(kernel) depth
|
||||||
int kH = INT_ARG(1); // filter(kernel) height
|
int kH = INT_ARG(1); // filter(kernel) height
|
||||||
|
@ -149,7 +149,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||||
|
|
||||||
const int kD = INT_ARG(0); // filter(kernel) depth
|
const int kD = INT_ARG(0); // filter(kernel) depth
|
||||||
const int kH = INT_ARG(1); // filter(kernel) height
|
const int kH = INT_ARG(1); // filter(kernel) height
|
||||||
|
|
|
@ -38,7 +38,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D OP: input array should have rank of 4, but got %i instead", input->rankOf());
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
const int kH = INT_ARG(0);
|
const int kH = INT_ARG(0);
|
||||||
const int kW = INT_ARG(1);
|
const int kW = INT_ARG(1);
|
||||||
|
@ -150,7 +150,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
|
||||||
int kH = INT_ARG(0); // filter(kernel) height
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
int kW = INT_ARG(1); // filter(kernel) width
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace ops {
|
||||||
CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
|
CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
auto output = OUTPUT_NULLIFIED(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
|
||||||
|
|
||||||
int kD = INT_ARG(0); // filter(kernel) depth
|
int kD = INT_ARG(0); // filter(kernel) depth
|
||||||
int kH = INT_ARG(1); // filter(kernel) height
|
int kH = INT_ARG(1); // filter(kernel) height
|
||||||
|
@ -151,7 +151,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
|
||||||
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||||
|
|
||||||
const int kD = INT_ARG(0); // filter(kernel) depth
|
const int kD = INT_ARG(0); // filter(kernel) depth
|
||||||
const int kH = INT_ARG(1); // filter(kernel) height
|
const int kH = INT_ARG(1); // filter(kernel) height
|
||||||
|
|
|
@ -30,14 +30,14 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(max_pool_with_argmax, 1, 2, false, 0, 9) {
|
CUSTOM_OP_IMPL(max_pool_with_argmax, 1, 2, false, 0, 9) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_NULLIFIED(0);
|
||||||
auto indeces = OUTPUT_VARIABLE(1);
|
auto indices = OUTPUT_NULLIFIED(1);
|
||||||
|
|
||||||
REQUIRE_TRUE(x->rankOf() == 4, 0, "max_pool_with_argmax: Input should have rank of 4, but got %i instead", x->rankOf());
|
REQUIRE_TRUE(x->rankOf() == 4, 0, "max_pool_with_argmax: Input should have rank of 4, but got %i instead", x->rankOf());
|
||||||
|
|
||||||
auto argI = *(block.getIArguments());
|
auto argI = *(block.getIArguments());
|
||||||
|
|
||||||
helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indeces);
|
helpers::maxPoolingFunctor(block.launchContext(), block, x, z, argI, indices);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace sd {
|
||||||
REQUIRE_OK(this->validateInputLengthMatch(block));
|
REQUIRE_OK(this->validateInputLengthMatch(block));
|
||||||
REQUIRE_OK(this->validateInputDimensionsMatch(block));
|
REQUIRE_OK(this->validateInputDimensionsMatch(block));
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "PNORMPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "PNORMPOOL2D op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||||
|
|
||||||
|
@ -145,7 +145,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
|
|
||||||
int kH = INT_ARG(0); // filter(kernel) height
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
int kW = INT_ARG(1); // filter(kernel) width
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
|
|
@ -33,7 +33,7 @@ CONFIGURABLE_OP_IMPL(dropout, 1, 1, true, 1, 1) {
|
||||||
auto input = INPUT_VARIABLE(0); // lookup param
|
auto input = INPUT_VARIABLE(0); // lookup param
|
||||||
|
|
||||||
NDArray *reduceShape = nullptr; // this param is optional
|
NDArray *reduceShape = nullptr; // this param is optional
|
||||||
auto output = OUTPUT_VARIABLE(0); //
|
auto output = OUTPUT_NULLIFIED(0); //
|
||||||
|
|
||||||
int seed = INT_ARG(0);
|
int seed = INT_ARG(0);
|
||||||
|
|
||||||
|
@ -66,7 +66,7 @@ CONFIGURABLE_OP_IMPL(dropout_bp, 2, 1, false, 1, 1) {
|
||||||
NDArray* gradOut = INPUT_VARIABLE(1); // lookup param
|
NDArray* gradOut = INPUT_VARIABLE(1); // lookup param
|
||||||
|
|
||||||
NDArray* reduceShape = nullptr; // this param is optional
|
NDArray* reduceShape = nullptr; // this param is optional
|
||||||
NDArray* output = OUTPUT_VARIABLE(0); //
|
NDArray* output = OUTPUT_NULLIFIED(0); //
|
||||||
|
|
||||||
int seed = INT_ARG(0);
|
int seed = INT_ARG(0);
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(lstsq, 2, 1, false, 0, 0) {
|
||||||
auto a = INPUT_VARIABLE(0);
|
auto a = INPUT_VARIABLE(0);
|
||||||
auto b = INPUT_VARIABLE(1);
|
auto b = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_NULLIFIED(0);
|
||||||
bool fastFlag = true;
|
bool fastFlag = true;
|
||||||
double l2_factor = 0.;
|
double l2_factor = 0.;
|
||||||
if (block.numB() > 0) {
|
if (block.numB() > 0) {
|
||||||
|
@ -56,7 +56,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(solve_ls, 2, 1, false, 0, 0) {
|
||||||
auto a = INPUT_VARIABLE(0);
|
auto a = INPUT_VARIABLE(0);
|
||||||
auto b = INPUT_VARIABLE(1);
|
auto b = INPUT_VARIABLE(1);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_NULLIFIED(0);
|
||||||
bool fastFlag = true;
|
bool fastFlag = true;
|
||||||
double l2_factor = 0.;
|
double l2_factor = 0.;
|
||||||
if (block.numB() > 0) {
|
if (block.numB() > 0) {
|
||||||
|
|
|
@ -114,7 +114,7 @@ namespace sd {
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(logdet, 1, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(logdet, 1, 1, false, 0, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() >=2, 0, "logdet: The rank of input array should not less than 2, but %i is given", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() >=2, 0, "logdet: The rank of input array should not less than 2, but %i is given", input->rankOf());
|
||||||
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "logdet: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
|
REQUIRE_TRUE(input->sizeAt(-1) == input->sizeAt(-2), 0, "logdet: The last two dimmensions should be equal, but %i and %i are given", input->sizeAt(-1), input->sizeAt(-2));
|
||||||
|
|
|
@ -76,8 +76,8 @@ namespace sd {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto indices = INPUT_VARIABLE(1);
|
auto indices = INPUT_VARIABLE(1);
|
||||||
auto gradOut = INPUT_VARIABLE(2);
|
auto gradOut = INPUT_VARIABLE(2);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
auto outIndices = OUTPUT_VARIABLE(1);
|
auto outIndices = OUTPUT_NULLIFIED(1);
|
||||||
outIndices->assign(indices);
|
outIndices->assign(indices);
|
||||||
return helpers::segmentMaxFunctorBP(block.launchContext(), input, indices, gradOut, output);
|
return helpers::segmentMaxFunctorBP(block.launchContext(), input, indices, gradOut, output);
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,8 +76,8 @@ namespace sd {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto indices = INPUT_VARIABLE(1);
|
auto indices = INPUT_VARIABLE(1);
|
||||||
auto gradOut = INPUT_VARIABLE(2);
|
auto gradOut = INPUT_VARIABLE(2);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
auto outIndices = OUTPUT_VARIABLE(1);
|
auto outIndices = OUTPUT_NULLIFIED(1);
|
||||||
outIndices->assign(indices);
|
outIndices->assign(indices);
|
||||||
return helpers::segmentMeanFunctorBP(block.launchContext(), input, indices, gradOut, output);
|
return helpers::segmentMeanFunctorBP(block.launchContext(), input, indices, gradOut, output);
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,8 +66,8 @@ namespace sd {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto indices = INPUT_VARIABLE(1);
|
auto indices = INPUT_VARIABLE(1);
|
||||||
auto gradOut = INPUT_VARIABLE(2);
|
auto gradOut = INPUT_VARIABLE(2);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
auto outIndices = OUTPUT_VARIABLE(1);
|
auto outIndices = OUTPUT_NULLIFIED(1);
|
||||||
outIndices->assign(indices);
|
outIndices->assign(indices);
|
||||||
return helpers::segmentMinFunctorBP(block.launchContext(), input, indices, gradOut, output);
|
return helpers::segmentMinFunctorBP(block.launchContext(), input, indices, gradOut, output);
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,8 +67,8 @@ namespace sd {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto indices = INPUT_VARIABLE(1);
|
auto indices = INPUT_VARIABLE(1);
|
||||||
auto gradOut = INPUT_VARIABLE(2);
|
auto gradOut = INPUT_VARIABLE(2);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
auto outIndices = OUTPUT_VARIABLE(1);
|
auto outIndices = OUTPUT_NULLIFIED(1);
|
||||||
outIndices->assign(indices);
|
outIndices->assign(indices);
|
||||||
helpers::segmentProdFunctorBP(block.launchContext(), input, indices, gradOut, output);
|
helpers::segmentProdFunctorBP(block.launchContext(), input, indices, gradOut, output);
|
||||||
|
|
||||||
|
|
|
@ -65,7 +65,7 @@ namespace sd {
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(segment_sum_bp, 3, 2, false, 0, 0) {
|
CUSTOM_OP_IMPL(segment_sum_bp, 3, 2, false, 0, 0) {
|
||||||
|
|
||||||
return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), OUTPUT_VARIABLE(0));
|
return helpers::segmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), OUTPUT_NULLIFIED(0));
|
||||||
}
|
}
|
||||||
DECLARE_SHAPE_FN(segment_sum_bp){
|
DECLARE_SHAPE_FN(segment_sum_bp){
|
||||||
Nd4jLong* in = inputShape->at(0);
|
Nd4jLong* in = inputShape->at(0);
|
||||||
|
|
|
@ -25,7 +25,7 @@ namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CUSTOM_OP_IMPL(sequence_mask, 1, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(sequence_mask, 1, 1, false, 0, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
const int inRank = input->rankOf();
|
const int inRank = input->rankOf();
|
||||||
|
|
||||||
//REQUIRE_TRUE(inRank >= 1, 0, "sequence_mask: input array must have rank >= 1, but %i given!", inRank);
|
//REQUIRE_TRUE(inRank >= 1, 0, "sequence_mask: input array must have rank >= 1, but %i given!", inRank);
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_max, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(unsorted_segment_max, 2, 1, false, 0, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto idxSegments = INPUT_VARIABLE(1);
|
auto idxSegments = INPUT_VARIABLE(1);
|
||||||
auto segmentedOutput = OUTPUT_VARIABLE(0);
|
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||||
|
@ -67,7 +67,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_max_bp, 3, 2, false, 0, 1) {
|
CUSTOM_OP_IMPL(unsorted_segment_max_bp, 3, 2, false, 0, 1) {
|
||||||
return helpers::unsortedSegmentMaxFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
|
return helpers::unsortedSegmentMaxFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(unsorted_segment_max_bp) {
|
DECLARE_TYPES(unsorted_segment_max_bp) {
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto idxSegments = INPUT_VARIABLE(1);
|
auto idxSegments = INPUT_VARIABLE(1);
|
||||||
auto segmentedOutput = OUTPUT_VARIABLE(0);
|
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||||
|
@ -69,7 +69,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_mean_bp, 3, 2, false, 0, 1) {
|
CUSTOM_OP_IMPL(unsorted_segment_mean_bp, 3, 2, false, 0, 1) {
|
||||||
return helpers::unsortedSegmentMeanFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
|
return helpers::unsortedSegmentMeanFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(unsorted_segment_mean_bp) {
|
DECLARE_TYPES(unsorted_segment_mean_bp) {
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto idxSegments = INPUT_VARIABLE(1);
|
auto idxSegments = INPUT_VARIABLE(1);
|
||||||
auto segmentedOutput = OUTPUT_VARIABLE(0);
|
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||||
|
@ -69,7 +69,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_min_bp, 3, 2, false, 0, 1) {
|
CUSTOM_OP_IMPL(unsorted_segment_min_bp, 3, 2, false, 0, 1) {
|
||||||
return helpers::unsortedSegmentMinFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
|
return helpers::unsortedSegmentMinFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(unsorted_segment_min_bp) {
|
DECLARE_TYPES(unsorted_segment_min_bp) {
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto idxSegments = INPUT_VARIABLE(1);
|
auto idxSegments = INPUT_VARIABLE(1);
|
||||||
auto segmentedOutput = OUTPUT_VARIABLE(0);
|
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||||
|
@ -72,7 +72,7 @@ namespace sd {
|
||||||
auto indices = INPUT_VARIABLE(1);
|
auto indices = INPUT_VARIABLE(1);
|
||||||
auto eps = INPUT_VARIABLE(2);
|
auto eps = INPUT_VARIABLE(2);
|
||||||
// auto numOfClasses = INT_ARG(0);
|
// auto numOfClasses = INT_ARG(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
Nd4jLong numOfClasses = block.width() == 4 ? INPUT_VARIABLE(3)->e<Nd4jLong>(0) : INT_ARG(0);
|
Nd4jLong numOfClasses = block.width() == 4 ? INPUT_VARIABLE(3)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||||
REQUIRE_TRUE(indices->isVector(), 0, "unsorted_segment_prod_bp: segment indexes array should be a vector, but it rank is %i.", indices->rankOf());
|
REQUIRE_TRUE(indices->isVector(), 0, "unsorted_segment_prod_bp: segment indexes array should be a vector, but it rank is %i.", indices->rankOf());
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto idxSegments = INPUT_VARIABLE(1);
|
auto idxSegments = INPUT_VARIABLE(1);
|
||||||
auto segmentedOutput = OUTPUT_VARIABLE(0);
|
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||||
|
@ -68,7 +68,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1) {
|
CUSTOM_OP_IMPL(unsorted_segment_sqrt_n_bp, 3, 2, false, 0, 1) {
|
||||||
return helpers::unsortedSegmentSqrtNFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
|
return helpers::unsortedSegmentSqrtNFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
|
||||||
}
|
}
|
||||||
DECLARE_TYPES(unsorted_segment_sqrt_n_bp) {
|
DECLARE_TYPES(unsorted_segment_sqrt_n_bp) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
|
|
|
@ -26,7 +26,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto idxSegments = INPUT_VARIABLE(1);
|
auto idxSegments = INPUT_VARIABLE(1);
|
||||||
auto segmentedOutput = OUTPUT_VARIABLE(0);
|
auto segmentedOutput = OUTPUT_NULLIFIED(0);
|
||||||
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : INT_ARG(0);
|
||||||
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf());
|
||||||
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %i != %i.", idxSegments->lengthOf(), input->sizeAt(0));
|
||||||
|
@ -67,7 +67,7 @@ namespace sd {
|
||||||
return SHAPELIST(CONSTANT(outputShape));
|
return SHAPELIST(CONSTANT(outputShape));
|
||||||
}
|
}
|
||||||
CUSTOM_OP_IMPL(unsorted_segment_sum_bp, 3, 2, false, 0, 1) {
|
CUSTOM_OP_IMPL(unsorted_segment_sum_bp, 3, 2, false, 0, 1) {
|
||||||
return helpers::unsortedSegmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_VARIABLE(0));
|
return helpers::unsortedSegmentSumFunctorBP(block.launchContext(), INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2), INT_ARG(0), OUTPUT_NULLIFIED(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(unsorted_segment_sum_bp){
|
DECLARE_SHAPE_FN(unsorted_segment_sum_bp){
|
||||||
|
|
|
@ -42,7 +42,7 @@ namespace sd {
|
||||||
CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
auto inputSamples = INPUT_VARIABLE(1);
|
auto inputSamples = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ namespace sd {
|
||||||
*/
|
*/
|
||||||
CUSTOM_OP_IMPL(firas_sparse, 1, 1, false, 0, -1) {
|
CUSTOM_OP_IMPL(firas_sparse, 1, 1, false, 0, -1) {
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
int batchSize = x->sizeAt(0);
|
int batchSize = x->sizeAt(0);
|
||||||
int numColumns = x->sizeAt(1);
|
int numColumns = x->sizeAt(1);
|
||||||
|
|
|
@ -34,7 +34,7 @@ namespace ops {
|
||||||
auto dataP = INPUT_VARIABLE(3);
|
auto dataP = INPUT_VARIABLE(3);
|
||||||
auto N = INT_ARG(0);
|
auto N = INT_ARG(0);
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_NULLIFIED(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(rowP->isVector(), 0, "barnes_edge_force: row input must be a vector, but its rank is %i instead !", rowP->rankOf());
|
REQUIRE_TRUE(rowP->isVector(), 0, "barnes_edge_force: row input must be a vector, but its rank is %i instead !", rowP->rankOf());
|
||||||
REQUIRE_TRUE(colP->isVector(), 0, "barnes_edge_force: col input must be a vector, but its rank is %i instead !", colP->rankOf());
|
REQUIRE_TRUE(colP->isVector(), 0, "barnes_edge_force: col input must be a vector, but its rank is %i instead !", colP->rankOf());
|
||||||
|
|
|
@ -54,8 +54,6 @@ void col2im_(sd::LaunchContext & context, const NDArray& input, NDArray& output
|
||||||
const Nd4jLong imStride1 = imStride[1];
|
const Nd4jLong imStride1 = imStride[1];
|
||||||
const Nd4jLong imStride2 = imStride[2];
|
const Nd4jLong imStride2 = imStride[2];
|
||||||
const Nd4jLong imStride3 = imStride[3];
|
const Nd4jLong imStride3 = imStride[3];
|
||||||
|
|
||||||
memset(imBuff, 0, shape::length(imShapeBuffer) * sizeof(T));
|
|
||||||
|
|
||||||
|
|
||||||
// if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) {
|
// if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) {
|
||||||
|
|
|
@ -116,6 +116,8 @@ void bgemm(const std::vector<NDArray*>& vA, const std::vector<NDArray*>& vB, std
|
||||||
const auto bType = pB[0]->dataType();
|
const auto bType = pB[0]->dataType();
|
||||||
const auto cType = pC[0]->dataType();
|
const auto cType = pC[0]->dataType();
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(*LaunchContext::deviceMutex());
|
||||||
|
|
||||||
auto handle = reinterpret_cast<cublasHandle_t*>(context->getCublasHandle());
|
auto handle = reinterpret_cast<cublasHandle_t*>(context->getCublasHandle());
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
|
|
|
@ -96,6 +96,14 @@ namespace sd {
|
||||||
return _descriptor->getHash();
|
return _descriptor->getHash();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sd::NDArray* sd::ops::DeclarableOp::getNullifiedZ(Context& block, int inputId) {
|
||||||
|
auto result = getZ(block, inputId);
|
||||||
|
if (result != nullptr && !block.isInplace())
|
||||||
|
result->nullify();
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
sd::NDArray* sd::ops::DeclarableOp::getZ(Context& ctx, int inputId) {
|
sd::NDArray* sd::ops::DeclarableOp::getZ(Context& ctx, int inputId) {
|
||||||
NDArray* z = nullptr;
|
NDArray* z = nullptr;
|
||||||
|
@ -294,7 +302,8 @@ namespace sd {
|
||||||
if (Environment::getInstance()->isDebugAndVerbose())
|
if (Environment::getInstance()->isDebugAndVerbose())
|
||||||
shape::printShapeInfoLinear("Going to create variable with shape", out);
|
shape::printShapeInfoLinear("Going to create variable with shape", out);
|
||||||
|
|
||||||
auto outArr = new NDArray(out, true, ctx.launchContext());
|
// we're creating non-initialized array here
|
||||||
|
auto outArr = new NDArray(out, true, ctx.launchContext(), false);
|
||||||
|
|
||||||
ctx.pushNDArrayToVariableSpace(pair, outArr);
|
ctx.pushNDArrayToVariableSpace(pair, outArr);
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,15 @@ namespace sd {
|
||||||
_engine = engine;
|
_engine = engine;
|
||||||
}
|
}
|
||||||
|
|
||||||
sd::NDArray *PlatformHelper::getZ(graph::Context &ctx, int inputId) {
|
sd::NDArray* PlatformHelper::getNullifiedZ(graph::Context& block, int inputId) {
|
||||||
|
auto result = getZ(block, inputId);
|
||||||
|
if (result != nullptr && !block.isInplace())
|
||||||
|
result->nullify();
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd::NDArray* PlatformHelper::getZ(graph::Context &ctx, int inputId) {
|
||||||
NDArray *z = nullptr;
|
NDArray *z = nullptr;
|
||||||
|
|
||||||
if (ctx.isFastPath()) {
|
if (ctx.isFastPath()) {
|
||||||
|
|
|
@ -540,9 +540,9 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, oC] always
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
|
||||||
|
|
||||||
int kH = INT_ARG(0); // filter(kernel) height
|
int kH = INT_ARG(0); // filter(kernel) height
|
||||||
int kW = INT_ARG(1); // filter(kernel) width
|
int kW = INT_ARG(1); // filter(kernel) width
|
||||||
|
|
|
@ -542,9 +542,9 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
|
||||||
auto gradW = OUTPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always
|
auto gradW = OUTPUT_NULLIFIED(1); // [kD, kH, kW, iC, oC] always
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
|
||||||
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
|
REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
|
||||||
|
|
|
@ -398,9 +398,9 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC]
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
|
auto gradI = OUTPUT_NULLIFIED(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon
|
||||||
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, mC] always
|
auto gradW = OUTPUT_NULLIFIED(1); // [kH, kW, iC, mC] always
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr; // [oC]
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
|
||||||
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
|
||||||
|
|
|
@ -1513,6 +1513,7 @@
|
||||||
|
|
||||||
#define INPUT_VARIABLE(INDEX) block.array(INDEX)
|
#define INPUT_VARIABLE(INDEX) block.array(INDEX)
|
||||||
#define OUTPUT_VARIABLE(INDEX) reinterpret_cast<sd::NDArray *>(this->getZ(block, INDEX))
|
#define OUTPUT_VARIABLE(INDEX) reinterpret_cast<sd::NDArray *>(this->getZ(block, INDEX))
|
||||||
|
#define OUTPUT_NULLIFIED(INDEX) reinterpret_cast<sd::NDArray *>(this->getNullifiedZ(block, INDEX))
|
||||||
|
|
||||||
#define INPUT_LIST(INDEX) reinterpret_cast<sd::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())
|
#define INPUT_LIST(INDEX) reinterpret_cast<sd::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())
|
||||||
|
|
||||||
|
|
|
@ -2128,8 +2128,10 @@ TEST_F(ConvolutionTests1, col2im_test1) {
|
||||||
|
|
||||||
auto imageExpected = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}, {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, 49.f, 103.f, 108.f, 226.f});
|
auto imageExpected = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}, {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, 49.f, 103.f, 108.f, 226.f});
|
||||||
|
|
||||||
LaunchContext ctx;
|
|
||||||
sd::ops::helpers::col2im(ctx, columns, image, sH, sW, pH, pW, iH, iW, dH, dW);
|
sd::ops::col2im op;
|
||||||
|
auto status = op.execute({&columns}, {&image}, {sH, sW, pH, pW, iH, iW, dH, dW, 0});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_TRUE(image.equalsTo(imageExpected));
|
ASSERT_TRUE(image.equalsTo(imageExpected));
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
||||||
public IntVectorVector(long n) { allocate(n); }
|
public IntVectorVector(long n) { allocate(n); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
private native void allocate(@Cast("size_t") long n);
|
private native void allocate(@Cast("size_t") long n);
|
||||||
public native @Name("operator=") @ByRef IntVectorVector put(@ByRef IntVectorVector x);
|
public native @Name("operator =") @ByRef IntVectorVector put(@ByRef IntVectorVector x);
|
||||||
|
|
||||||
public boolean empty() { return size() == 0; }
|
public boolean empty() { return size() == 0; }
|
||||||
public native long size();
|
public native long size();
|
||||||
|
@ -67,7 +67,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
||||||
public LongVectorVector(long n) { allocate(n); }
|
public LongVectorVector(long n) { allocate(n); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
private native void allocate(@Cast("size_t") long n);
|
private native void allocate(@Cast("size_t") long n);
|
||||||
public native @Name("operator=") @ByRef LongVectorVector put(@ByRef LongVectorVector x);
|
public native @Name("operator =") @ByRef LongVectorVector put(@ByRef LongVectorVector x);
|
||||||
|
|
||||||
public boolean empty() { return size() == 0; }
|
public boolean empty() { return size() == 0; }
|
||||||
public native long size();
|
public native long size();
|
||||||
|
@ -117,7 +117,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
||||||
public NDArrayVector(long n) { allocate(n); }
|
public NDArrayVector(long n) { allocate(n); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
private native void allocate(@Cast("size_t") long n);
|
private native void allocate(@Cast("size_t") long n);
|
||||||
public native @Name("operator=") @ByRef NDArrayVector put(@ByRef NDArrayVector x);
|
public native @Name("operator =") @ByRef NDArrayVector put(@ByRef NDArrayVector x);
|
||||||
|
|
||||||
public boolean empty() { return size() == 0; }
|
public boolean empty() { return size() == 0; }
|
||||||
public native long size();
|
public native long size();
|
||||||
|
@ -135,9 +135,9 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
||||||
public Iterator(Pointer p) { super(p); }
|
public Iterator(Pointer p) { super(p); }
|
||||||
public Iterator() { }
|
public Iterator() { }
|
||||||
|
|
||||||
public native @Name("operator++") @ByRef Iterator increment();
|
public native @Name("operator ++") @ByRef Iterator increment();
|
||||||
public native @Name("operator==") boolean equals(@ByRef Iterator it);
|
public native @Name("operator ==") boolean equals(@ByRef Iterator it);
|
||||||
public native @Name("operator*") @Const NDArray get();
|
public native @Name("operator *") @Const NDArray get();
|
||||||
}
|
}
|
||||||
|
|
||||||
public NDArray[] get() {
|
public NDArray[] get() {
|
||||||
|
@ -185,7 +185,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
||||||
public ConstNDArrayVector(long n) { allocate(n); }
|
public ConstNDArrayVector(long n) { allocate(n); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
private native void allocate(@Cast("size_t") long n);
|
private native void allocate(@Cast("size_t") long n);
|
||||||
public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x);
|
public native @Name("operator =") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x);
|
||||||
|
|
||||||
public boolean empty() { return size() == 0; }
|
public boolean empty() { return size() == 0; }
|
||||||
public native long size();
|
public native long size();
|
||||||
|
@ -203,9 +203,9 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
||||||
public Iterator(Pointer p) { super(p); }
|
public Iterator(Pointer p) { super(p); }
|
||||||
public Iterator() { }
|
public Iterator() { }
|
||||||
|
|
||||||
public native @Name("operator++") @ByRef Iterator increment();
|
public native @Name("operator ++") @ByRef Iterator increment();
|
||||||
public native @Name("operator==") boolean equals(@ByRef Iterator it);
|
public native @Name("operator ==") boolean equals(@ByRef Iterator it);
|
||||||
public native @Name("operator*") @Const NDArray get();
|
public native @Name("operator *") @Const NDArray get();
|
||||||
}
|
}
|
||||||
|
|
||||||
public NDArray[] get() {
|
public NDArray[] get() {
|
||||||
|
@ -250,7 +250,7 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
||||||
public IntIntPair(int firstValue, int secondValue) { this(); put(firstValue, secondValue); }
|
public IntIntPair(int firstValue, int secondValue) { this(); put(firstValue, secondValue); }
|
||||||
public IntIntPair() { allocate(); }
|
public IntIntPair() { allocate(); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
public native @Name("operator=") @ByRef IntIntPair put(@ByRef IntIntPair x);
|
public native @Name("operator =") @ByRef IntIntPair put(@ByRef IntIntPair x);
|
||||||
|
|
||||||
|
|
||||||
@MemberGetter public native int first(); public native IntIntPair first(int first);
|
@MemberGetter public native int first(); public native IntIntPair first(int first);
|
||||||
|
@ -3733,16 +3733,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
/**
|
/**
|
||||||
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
||||||
*/
|
*/
|
||||||
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo);
|
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo);
|
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo);
|
||||||
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo);
|
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo);
|
||||||
|
|
||||||
|
@ -3750,16 +3750,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
||||||
* set dtype as array type
|
* set dtype as array type
|
||||||
*/
|
*/
|
||||||
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype);
|
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype);
|
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype);
|
||||||
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype);
|
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype);
|
||||||
|
|
||||||
|
@ -9492,6 +9492,14 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public native NDArray getZ(@ByRef Context ctx, int inputId);
|
public native NDArray getZ(@ByRef Context ctx, int inputId);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper method, needed for compatibility with DeclarableOp macros
|
||||||
|
* @param ctx
|
||||||
|
* @param inputId
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -10289,7 +10297,6 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
|
|
||||||
// #ifndef __JAVACPP_HACK__
|
// #ifndef __JAVACPP_HACK__
|
||||||
|
|
||||||
|
|
||||||
// #endif // JCPP
|
// #endif // JCPP
|
||||||
|
|
||||||
// #endif // CUDA
|
// #endif // CUDA
|
||||||
|
@ -10308,6 +10315,10 @@ public static final int PREALLOC_SIZE = 33554432;
|
||||||
public native void setDeviceID(int deviceID);
|
public native void setDeviceID(int deviceID);
|
||||||
public native ErrorReference errorReference();
|
public native ErrorReference errorReference();
|
||||||
|
|
||||||
|
// #ifndef __JAVACPP_HACK__
|
||||||
|
|
||||||
|
// #endif
|
||||||
|
|
||||||
public static native @Cast("bool") boolean isInitialized();
|
public static native @Cast("bool") boolean isInitialized();
|
||||||
public static native void releaseBuffers();
|
public static native void releaseBuffers();
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
public IntVectorVector(long n) { allocate(n); }
|
public IntVectorVector(long n) { allocate(n); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
private native void allocate(@Cast("size_t") long n);
|
private native void allocate(@Cast("size_t") long n);
|
||||||
public native @Name("operator=") @ByRef IntVectorVector put(@ByRef IntVectorVector x);
|
public native @Name("operator =") @ByRef IntVectorVector put(@ByRef IntVectorVector x);
|
||||||
|
|
||||||
public boolean empty() { return size() == 0; }
|
public boolean empty() { return size() == 0; }
|
||||||
public native long size();
|
public native long size();
|
||||||
|
@ -70,7 +70,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
public LongVectorVector(long n) { allocate(n); }
|
public LongVectorVector(long n) { allocate(n); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
private native void allocate(@Cast("size_t") long n);
|
private native void allocate(@Cast("size_t") long n);
|
||||||
public native @Name("operator=") @ByRef LongVectorVector put(@ByRef LongVectorVector x);
|
public native @Name("operator =") @ByRef LongVectorVector put(@ByRef LongVectorVector x);
|
||||||
|
|
||||||
public boolean empty() { return size() == 0; }
|
public boolean empty() { return size() == 0; }
|
||||||
public native long size();
|
public native long size();
|
||||||
|
@ -120,7 +120,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
public ConstNDArrayVector(long n) { allocate(n); }
|
public ConstNDArrayVector(long n) { allocate(n); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
private native void allocate(@Cast("size_t") long n);
|
private native void allocate(@Cast("size_t") long n);
|
||||||
public native @Name("operator=") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x);
|
public native @Name("operator =") @ByRef ConstNDArrayVector put(@ByRef ConstNDArrayVector x);
|
||||||
|
|
||||||
public boolean empty() { return size() == 0; }
|
public boolean empty() { return size() == 0; }
|
||||||
public native long size();
|
public native long size();
|
||||||
|
@ -138,9 +138,9 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
public Iterator(Pointer p) { super(p); }
|
public Iterator(Pointer p) { super(p); }
|
||||||
public Iterator() { }
|
public Iterator() { }
|
||||||
|
|
||||||
public native @Name("operator++") @ByRef Iterator increment();
|
public native @Name("operator ++") @ByRef Iterator increment();
|
||||||
public native @Name("operator==") boolean equals(@ByRef Iterator it);
|
public native @Name("operator ==") boolean equals(@ByRef Iterator it);
|
||||||
public native @Name("operator*") @Const NDArray get();
|
public native @Name("operator *") @Const NDArray get();
|
||||||
}
|
}
|
||||||
|
|
||||||
public NDArray[] get() {
|
public NDArray[] get() {
|
||||||
|
@ -188,7 +188,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
public NDArrayVector(long n) { allocate(n); }
|
public NDArrayVector(long n) { allocate(n); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
private native void allocate(@Cast("size_t") long n);
|
private native void allocate(@Cast("size_t") long n);
|
||||||
public native @Name("operator=") @ByRef NDArrayVector put(@ByRef NDArrayVector x);
|
public native @Name("operator =") @ByRef NDArrayVector put(@ByRef NDArrayVector x);
|
||||||
|
|
||||||
public boolean empty() { return size() == 0; }
|
public boolean empty() { return size() == 0; }
|
||||||
public native long size();
|
public native long size();
|
||||||
|
@ -206,9 +206,9 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
public Iterator(Pointer p) { super(p); }
|
public Iterator(Pointer p) { super(p); }
|
||||||
public Iterator() { }
|
public Iterator() { }
|
||||||
|
|
||||||
public native @Name("operator++") @ByRef Iterator increment();
|
public native @Name("operator ++") @ByRef Iterator increment();
|
||||||
public native @Name("operator==") boolean equals(@ByRef Iterator it);
|
public native @Name("operator ==") boolean equals(@ByRef Iterator it);
|
||||||
public native @Name("operator*") @Const NDArray get();
|
public native @Name("operator *") @Const NDArray get();
|
||||||
}
|
}
|
||||||
|
|
||||||
public NDArray[] get() {
|
public NDArray[] get() {
|
||||||
|
@ -253,7 +253,7 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
public IntIntPair(int firstValue, int secondValue) { this(); put(firstValue, secondValue); }
|
public IntIntPair(int firstValue, int secondValue) { this(); put(firstValue, secondValue); }
|
||||||
public IntIntPair() { allocate(); }
|
public IntIntPair() { allocate(); }
|
||||||
private native void allocate();
|
private native void allocate();
|
||||||
public native @Name("operator=") @ByRef IntIntPair put(@ByRef IntIntPair x);
|
public native @Name("operator =") @ByRef IntIntPair put(@ByRef IntIntPair x);
|
||||||
|
|
||||||
|
|
||||||
@MemberGetter public native int first(); public native IntIntPair first(int first);
|
@MemberGetter public native int first(); public native IntIntPair first(int first);
|
||||||
|
@ -3736,16 +3736,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
/**
|
/**
|
||||||
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
||||||
*/
|
*/
|
||||||
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo);
|
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo);
|
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo);
|
||||||
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo) { super((Pointer)null); allocate(shapeInfo); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo);
|
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo);
|
||||||
|
|
||||||
|
@ -3753,16 +3753,16 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
* constructor creates new NDArray using shape information from "shapeInfo", set all elements in new array to be zeros, if copyStrides is true then use stride values from "shapeInfo", else calculate strides independently
|
||||||
* set dtype as array type
|
* set dtype as array type
|
||||||
*/
|
*/
|
||||||
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
public NDArray(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype);
|
private native void allocate(@Cast("Nd4jLong*") LongPointer shapeInfo, @Cast("const sd::DataType") int dtype);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
public NDArray(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype);
|
private native void allocate(@Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("const sd::DataType") int dtype);
|
||||||
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context); }
|
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/) { super((Pointer)null); allocate(shapeInfo, dtype, copyStrides, context, nullify); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/);
|
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype, @Cast("const bool") boolean copyStrides/*=false*/, LaunchContext context/*=sd::LaunchContext::defaultContext()*/, @Cast("const bool") boolean nullify/*=true*/);
|
||||||
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
public NDArray(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype) { super((Pointer)null); allocate(shapeInfo, dtype); }
|
||||||
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype);
|
private native void allocate(@Cast("Nd4jLong*") long[] shapeInfo, @Cast("const sd::DataType") int dtype);
|
||||||
|
|
||||||
|
@ -11459,6 +11459,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
|
|
||||||
// #define INPUT_VARIABLE(INDEX) block.array(INDEX)
|
// #define INPUT_VARIABLE(INDEX) block.array(INDEX)
|
||||||
// #define OUTPUT_VARIABLE(INDEX) reinterpret_cast<sd::NDArray *>(this->getZ(block, INDEX))
|
// #define OUTPUT_VARIABLE(INDEX) reinterpret_cast<sd::NDArray *>(this->getZ(block, INDEX))
|
||||||
|
// #define OUTPUT_NULLIFIED(INDEX) reinterpret_cast<sd::NDArray *>(this->getNullifiedZ(block, INDEX))
|
||||||
|
|
||||||
// #define INPUT_LIST(INDEX) reinterpret_cast<sd::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())
|
// #define INPUT_LIST(INDEX) reinterpret_cast<sd::NDArrayList *>(block.getVariable(INDEX)->getNDArrayList())
|
||||||
|
|
||||||
|
@ -11809,6 +11810,14 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public native NDArray getZ(@ByRef Context ctx, int inputId);
|
public native NDArray getZ(@ByRef Context ctx, int inputId);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper method, needed for compatibility with DeclarableOp macros
|
||||||
|
* @param ctx
|
||||||
|
* @param inputId
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public native NDArray getNullifiedZ(@ByRef Context ctx, int inputId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -24205,6 +24214,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
public native void setDeviceID(int deviceID);
|
public native void setDeviceID(int deviceID);
|
||||||
public native ErrorReference errorReference();
|
public native ErrorReference errorReference();
|
||||||
|
|
||||||
|
// #ifndef __JAVACPP_HACK__
|
||||||
|
|
||||||
|
// #endif
|
||||||
|
|
||||||
public static native @Cast("bool") boolean isInitialized();
|
public static native @Cast("bool") boolean isInitialized();
|
||||||
public static native void releaseBuffers();
|
public static native void releaseBuffers();
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ import org.junit.rules.TemporaryFolder;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.Parameterized;
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.imports.TFGraphs.NodeReader;
|
import org.nd4j.imports.TFGraphs.NodeReader;
|
||||||
|
import org.nd4j.linalg.api.blas.BlasBufferUtil;
|
||||||
import org.nd4j.linalg.api.blas.Level1;
|
import org.nd4j.linalg.api.blas.Level1;
|
||||||
import org.nd4j.linalg.api.blas.params.GemmParams;
|
import org.nd4j.linalg.api.blas.params.GemmParams;
|
||||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
|
@ -106,6 +107,7 @@ import java.nio.ByteOrder;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
|
Loading…
Reference in New Issue