[WIP] Error handling (#169)

* CUDA reverse rewrite + couple of tests

Signed-off-by: raver119 <raver119@gmail.com>

* don't throw exception on invalid pointer

Signed-off-by: raver119 <raver119@gmail.com>

* data types validation for fastpath exec mode + 2 tests

Signed-off-by: raver119 <raver119@gmail.com>

* data types validation for fastpath exec mode + 2 tests

Signed-off-by: raver119 <raver119@gmail.com>

* ismax allowed dtypes tweak

Signed-off-by: raver119 <raver119@gmail.com>

* lastErrorCode + lastErrorMessage for native exceptions handling

Signed-off-by: raver119 <raver119@gmail.com>

* exportable ErrorReference

Signed-off-by: raver119 <raver119@gmail.com>

* check error codes in java

Signed-off-by: raver119 <raver119@gmail.com>

* - consume lastErrorCode
- fast_in dtype validation fix

Signed-off-by: raver119 <raver119@gmail.com>

* - sg/cb allowed output type change
- minor logging fix for data type validation

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-26 19:57:51 +03:00 committed by GitHub
parent bb5fc36e5e
commit 25e5c23eae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 3002 additions and 3072 deletions

View File

@ -79,6 +79,18 @@ bool verbose = false;
extern "C" { extern "C" {
/**
* This function returns last error code stored,
* @return non-zero if something bad happened
*/
ND4J_EXPORT int lastErrorCode();
/**
* This function returns last error message, if last error code > 0
* @return
*/
ND4J_EXPORT const char* lastErrorMessage();
/** /**
* *
* @param p * @param p
@ -557,38 +569,6 @@ ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
/**
* Append an input array
* to the end of a flat array
* in a particular order
* @param offset the offset of the array to start at
* @param order the order
* @param result the result array
* @param resultShapeInfo the shape info for te array
* @param input the input for the array
* @param inputShapeInfo the shape information for that array
*/
ND4J_EXPORT void flatten(
Nd4jPointer *extraPointers,
int offset,
char order,
void *result, Nd4jLong *resultShapeInfo,
void *dresult, Nd4jLong *dresultShapeInfo,
void *input, Nd4jLong *inputShapeInfo,
void *dinput, Nd4jLong *dinputShapeInfo);
ND4J_EXPORT void concat(
Nd4jPointer *extraPointers,
int dimension,
int numArrays,
Nd4jPointer *data, Nd4jPointer *inputShapeInfo,
Nd4jPointer *ddata, Nd4jPointer *dinputShapeInfo,
void *result, Nd4jLong *resultShapeInfo,
void *dresult, Nd4jLong *dresultShapeInfo,
Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers);
ND4J_EXPORT void specialConcat ( ND4J_EXPORT void specialConcat (
Nd4jPointer *extraPointers, Nd4jPointer *extraPointers,
int dimension, int dimension,

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -23,6 +23,7 @@
#include <dll.h> #include <dll.h>
#include <pointercast.h> #include <pointercast.h>
#include <execution/ErrorReference.h>
namespace nd4j { namespace nd4j {
class ND4J_EXPORT ContextBuffers { class ND4J_EXPORT ContextBuffers {
@ -32,6 +33,7 @@ namespace nd4j {
void* _allocationPointer = nullptr; void* _allocationPointer = nullptr;
void* _execStream = nullptr; void* _execStream = nullptr;
void* _specialStream = nullptr; void* _specialStream = nullptr;
sd::ErrorReference _errorReference;
bool _allocated = false; bool _allocated = false;
bool _initialized = false; bool _initialized = false;
@ -60,6 +62,8 @@ namespace nd4j {
void setScalarBuffer(void* pointer); void setScalarBuffer(void* pointer);
void setAllocationBuffer(void* pointer); void setAllocationBuffer(void* pointer);
sd::ErrorReference* errorReference();
void triggerOwnership(bool isOwner); void triggerOwnership(bool isOwner);
int deviceId(); int deviceId();

View File

@ -15,32 +15,32 @@
******************************************************************************/ ******************************************************************************/
// //
// Created by Yurii Shyrma on 27.01.2018 // @author raver119@gmail.com
// //
#ifndef LIBND4J_PROVIDERRNG_H #ifndef DEV_TESTS_ERRORREFERENCE_H
#define LIBND4J_PROVIDERRNG_H #define DEV_TESTS_ERRORREFERENCE_H
#include <helpers/helper_random.h> #include <string>
#include <mutex> #include <dll.h>
namespace nd4j {
class ProviderRNG {
protected:
random::RandomBuffer* _rng;
static std::mutex _mutex;
ProviderRNG();
namespace sd {
class ND4J_EXPORT ErrorReference {
private:
int _errorCode = 0;
std::string _errorMessage;
public: public:
ProviderRNG(const ProviderRNG&) = delete; ErrorReference() = default;
void operator=(const ProviderRNG&) = delete; ~ErrorReference() = default;
random::RandomBuffer* getRNG() const;
static ProviderRNG& getInstance();
};
int errorCode();
const char* errorMessage();
void setErrorCode(int errorCode);
void setErrorMessage(std::string message);
void setErrorMessage(const char* message);
};
} }
#endif //LIBND4J_PROVIDERRNG_H
#endif //DEV_TESTS_ERRORREFERENCE_H

View File

@ -37,6 +37,7 @@
#include <vector> #include <vector>
#include <mutex> #include <mutex>
#include <execution/ContextBuffers.h> #include <execution/ContextBuffers.h>
#include <execution/ErrorReference.h>
@ -97,9 +98,12 @@ class ND4J_EXPORT LaunchContext {
int getDeviceID() const {return _deviceID;} int getDeviceID() const {return _deviceID;}
void setDeviceID(int deviceID) { _deviceID = deviceID; } void setDeviceID(int deviceID) { _deviceID = deviceID; }
sd::ErrorReference* errorReference();
static bool isInitialized(); static bool isInitialized();
static void releaseBuffers(); static void releaseBuffers();
static LaunchContext* defaultContext(); static LaunchContext* defaultContext();

View File

@ -99,4 +99,8 @@ namespace nd4j {
ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) { ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) {
return *this; return *this;
} }
sd::ErrorReference* ContextBuffers::errorReference() {
return &_errorReference;
}
} }

View File

@ -23,7 +23,11 @@
#include <exceptions/cuda_exception.h> #include <exceptions/cuda_exception.h>
#include <thread> #include <thread>
#ifdef IOS_BUILD
nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers(); nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
#else
thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
#endif
namespace nd4j { namespace nd4j {
@ -65,4 +69,8 @@ namespace nd4j {
void LaunchContext::releaseBuffers() { void LaunchContext::releaseBuffers() {
// //
} }
sd::ErrorReference* LaunchContext::errorReference() {
return contextBuffers.errorReference();
}
} }

View File

@ -220,5 +220,9 @@ namespace nd4j {
bool ContextBuffers::isInitialized() { bool ContextBuffers::isInitialized() {
return _initialized; return _initialized;
} }
sd::ErrorReference* ContextBuffers::errorReference() {
return &_errorReference;
}
} }

View File

@ -168,4 +168,8 @@ LaunchContext::LaunchContext() {
bool LaunchContext::isInitialized() { bool LaunchContext::isInitialized() {
return contextBuffers.isInitialized(); return contextBuffers.isInitialized();
} }
sd::ErrorReference* LaunchContext::errorReference() {
return contextBuffers.errorReference();
}
} }

View File

@ -15,37 +15,32 @@
******************************************************************************/ ******************************************************************************/
// //
// Created by Yurii Shyrma on 27.01.2018 // @author raver119@gmail.com
// //
#include <helpers/ProviderRNG.h>
#include <NativeOps.h>
#include <execution/ErrorReference.h>
namespace nd4j { namespace sd {
int ErrorReference::errorCode() {
ProviderRNG::ProviderRNG() { return _errorCode;
}
Nd4jLong *buffer = new Nd4jLong[100000]; const char* ErrorReference::errorMessage() {
std::lock_guard<std::mutex> lock(_mutex); // since we're fetching error message - error code will be assumed consumed & nullified
#ifndef __CUDABLAS__ _errorCode = 0;
// at this moment we don't have streams etc, so let's just skip this for now return _errorMessage.c_str();
_rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer); }
#endif
// if(_rng != nullptr) void ErrorReference::setErrorCode(int errorCode) {
} _errorCode = errorCode;
}
ProviderRNG& ProviderRNG::getInstance() {
void ErrorReference::setErrorMessage(std::string message) {
static ProviderRNG instance; _errorMessage = message;
return instance; }
}
void ErrorReference::setErrorMessage(const char* message) {
random::RandomBuffer* ProviderRNG::getRNG() const { _errorMessage = std::string(message);
}
return _rng;
}
std::mutex ProviderRNG::_mutex;
} }

View File

@ -45,7 +45,7 @@ DECLARE_SYN(IsMax, ismax);
DECLARE_TYPES(ismax) { DECLARE_TYPES(ismax) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY) ->setAllowedInputTypes(0, DataType::ANY)
->setAllowedOutputTypes(0, DataType::BOOL); ->setAllowedOutputTypes(0, DataType::ANY);
} }

View File

@ -84,7 +84,8 @@ namespace nd4j {
->setAllowedInputTypes(11, nd4j::DataType::INT64) ->setAllowedInputTypes(11, nd4j::DataType::INT64)
->setAllowedInputTypes(12, nd4j::DataType::INT32) ->setAllowedInputTypes(12, nd4j::DataType::INT32)
->setAllowedInputTypes(13, nd4j::DataType::INT32) ->setAllowedInputTypes(13, nd4j::DataType::INT32)
->setAllowedInputTypes(14, {ALL_FLOATS}); ->setAllowedInputTypes(14, {ALL_FLOATS})
->setAllowedOutputTypes(nd4j::DataType::ANY);
} }
} }
} }

View File

@ -79,7 +79,7 @@ namespace nd4j {
->setAllowedInputTypes(9, {ALL_FLOATS}) ->setAllowedInputTypes(9, {ALL_FLOATS})
->setAllowedInputTypes(10, nd4j::DataType::INT64) ->setAllowedInputTypes(10, nd4j::DataType::INT64)
->setAllowedInputTypes(11, {ALL_FLOATS}) ->setAllowedInputTypes(11, {ALL_FLOATS})
->setAllowedOutputTypes(nd4j::DataType::INT8); ->setAllowedOutputTypes(nd4j::DataType::ANY);
} }
/* /*

View File

@ -70,7 +70,7 @@ CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) {
DECLARE_TYPES(softmax_bp) { DECLARE_TYPES(softmax_bp) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(DataType::ANY) ->setAllowedInputTypes({ALL_FLOATS})
->setAllowedOutputTypes({ALL_FLOATS}); ->setAllowedOutputTypes({ALL_FLOATS});
} }

View File

@ -30,51 +30,9 @@ namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
template <typename T>
inline void __device__ indexSwap(T* arr, Nd4jLong idx1, Nd4jLong idx2) {
T tmp = arr[idx1];
arr[idx1] = arr[idx2];
arr[idx2] = tmp;
}
// template <typename T>
// void reverseArray(nd4j::LaunchContext * context, void* inArr, Nd4jLong *inShapeBuffer, void *result, Nd4jLong *zShapeBuffer, int numOfElemsToReverse = 0);
/////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void reverseArrayInplaceKernel(void *input, Nd4jLong *inputShape, Nd4jLong numOfElemsToReverse) {
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
const auto step = gridDim.x * blockDim.x;
__shared__ Nd4jLong length;
__shared__ int linearStatus;
__shared__ T* inputArr;
if (threadIdx.x == 0) {
length = shape::length(inputShape);
linearStatus = shape::elementWiseStride(inputShape);
inputArr = reinterpret_cast<T*>(input);
}
__syncthreads();
for (Nd4jLong e = tid; e < numOfElemsToReverse / 2; e += step) {
if (linearStatus == 1) {
auto idx = numOfElemsToReverse - e - 1;
indexSwap(inputArr, e, idx);
}
else if (linearStatus > 1) {
auto idx1 = (numOfElemsToReverse - e - 1) * linearStatus;
Nd4jLong idx2 = e * linearStatus;
indexSwap(inputArr, idx1, idx2);
}
else {
auto inOffset = shape::getIndexOffset(e, inputShape, length);
auto outOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape, length);
indexSwap(inputArr, inOffset, outOffset);
}
}
}
template <typename T> template <typename T>
static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) { static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) {
const auto tid = blockIdx.x * gridDim.x + threadIdx.x; const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto step = gridDim.x * blockDim.x; const auto step = gridDim.x * blockDim.x;
__shared__ Nd4jLong length; __shared__ Nd4jLong length;
__shared__ int linearStatus; __shared__ int linearStatus;
@ -93,51 +51,47 @@ namespace helpers {
} }
__syncthreads(); __syncthreads();
for (Nd4jLong e = tid; e < length; e += step) { auto odd = length % 2 != 0;
if (e < numOfElemsToReverse ) { auto limit = length / 2;
if (linearStatus == 1) {
auto idx = numOfElemsToReverse - e - 1; for (Nd4jLong e = tid; e < limit; e += step) {
outputArr[idx] = inputArr[e]; // we're calculating offsets within input array
} else if (linearStatus > 1) { auto fOffset = shape::getIndexOffset(e, inputShape, length);
auto idx1 = (numOfElemsToReverse - e - 1) * linearStatus; auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape, length);
Nd4jLong idx2 = e * linearStatus;
outputArr[idx1] = inputArr[idx2]; // now we're storing input values
} else { auto v1 = inputArr[fOffset];
auto inOffset = shape::getIndexOffset(e, inputShape, length); auto v2 = inputArr[lOffset];
auto outOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape, length);
outputArr[outOffset] = inputArr[inOffset]; // now we're calculating offsets within output array
} auto zfOffset = shape::getIndexOffset(e, outputShape, length);
} auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape, length);
else {
if (linearStatus == 1) { // and saving values to output arrays
outputArr[e] = inputArr[e]; outputArr[zfOffset] = v2;
} else if (linearStatus > 1) { outputArr[zlOffset] = v1;
auto idx1 = e * linearStatus;
Nd4jLong idx2 = e * linearStatus; //printf("TID: %i; E: %lld; z[%lld], z[%lld] = x[%lld], x[%lld];\n", tid, e, zfOffset, zlOffset, lOffset, fOffset);
outputArr[idx1] = inputArr[idx2];
} else {
auto inOffset = shape::getIndexOffset(e, inputShape, length);
auto outOffset = shape::getIndexOffset(e, outputShape, length);
outputArr[outOffset] = inputArr[inOffset];
}
}
} }
//printf("\n"); // in case of odd array we'll have to move middle value
if (odd && tid == 0) {
auto xOffset = shape::getIndexOffset(limit, inputShape, length);
auto zOffset = shape::getIndexOffset(limit, outputShape, length);
outputArr[zOffset] = inputArr[xOffset];
//printf("TID: %i; E: %lld; z[%lld] = x[%lld];\n", tid, limit, zOffset, xOffset);
}
} }
template<typename T> template<typename T>
static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int numOfElemsToReverse) { static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
auto stream = context->getCudaStream(); auto stream = context->getCudaStream();
Nd4jLong numOfReverse = numOfElemsToReverse; Nd4jLong numOfReverse = numOfElemsToReverse;
if (numOfElemsToReverse == 0) if (numOfElemsToReverse == 0)
numOfReverse = input->lengthOf(); numOfReverse = input->lengthOf();
if (input == output) {
reverseArrayInplaceKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), numOfReverse); reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
}
else {
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
}
} }
@ -221,7 +175,7 @@ namespace helpers {
delete listIn; delete listIn;
} }
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, int numOfElemsToReverse), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
} }
} }

View File

@ -19,7 +19,6 @@
// //
#include <ops/declarable/DeclarableOp.h> #include <ops/declarable/DeclarableOp.h>
#include <helpers/ProviderRNG.h>
#include <Status.h> #include <Status.h>
#include <helpers/ShapeUtils.h> #include <helpers/ShapeUtils.h>
#include <NDArrayFactory.h> #include <NDArrayFactory.h>
@ -190,32 +189,6 @@ namespace nd4j {
auto outSha = this->calculateOutputShape(&inSha, ctx); auto outSha = this->calculateOutputShape(&inSha, ctx);
results = outSha->size(); results = outSha->size();
// we must "validate" our output shapes
/*
for (int e = 0; e < results; e++) {
auto ptr = outSha->at(e);
// checking for the same pointer used twice
for (int i = 0; i < results; i++){
if (i == e)
continue;
auto com = outSha->at(i);
if (ptr == com)
throw std::runtime_error("ShapeFunction returned same shape instance twice [" + *_descriptor->getOpName() + "]");
}
// checking for input pointer returned back
for (int i = 0; i < inSha.size(); i++){
auto com = inSha.at(i);
if (ptr == com)
throw std::runtime_error("ShapeFunction returned input shape instance as output [" + *_descriptor->getOpName() + "]");
}
}
*/
// optionally saving shapeTime // optionally saving shapeTime
if (Environment::getInstance()->isProfiling() && node != nullptr) { if (Environment::getInstance()->isProfiling() && node != nullptr) {
shapeEnd = std::chrono::system_clock::now(); shapeEnd = std::chrono::system_clock::now();
@ -355,75 +328,139 @@ namespace nd4j {
// rolling over inputs first // rolling over inputs first
int cnt = 0, inT = 0; int cnt = 0, inT = 0;
std::vector<nd4j::DataType> inputTypes(block.width()); std::vector<nd4j::DataType> inputTypes(block.width());
for (auto &p: *(block.inputs())) { if (block.isFastPath()) {
auto var = block.variable(p); for (auto array: block.fastpath_in()) {
// we're not checking validity, if ANY types were explicitly allowed
//if (block.dataType(cnt) == nd4j::DataType::ANY)
// continue;
// only validating non-null variables
if (var != nullptr && var->hasNDArray()) {
auto array = var->getNDArray();
inputTypes[inT++] = array->dataType(); inputTypes[inT++] = array->dataType();
if (!_descriptor->checkInputMatch(cnt, array->dataType())) { if (!_descriptor->checkInputMatch(cnt, array->dataType())) {
auto ctype = DataTypeUtils::asString(array->dataType()); auto ctype = DataTypeUtils::asString(array->dataType());
nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n", _descriptor->getOpName()->data(), cnt, ctype.c_str()); nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n",
_descriptor->getOpName()->data(), cnt, ctype.c_str());
return ND4J_STATUS_BAD_ARGUMENTS; return ND4J_STATUS_BAD_ARGUMENTS;
} }
cnt++;
} }
} else {
for (auto &p: *(block.inputs())) {
auto var = block.variable(p);
cnt++; // we're not checking validity, if ANY types were explicitly allowed
} //if (block.dataType(cnt) == nd4j::DataType::ANY)
// continue;
// checking optionally available outputs
auto varSpace = block.getVariableSpace();
for (int index = 0; index < DataTypeUtils::max<int>(); index++) {
if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) {
auto var = block.variable(block.nodeId(), index);
// only validating non-null variables // only validating non-null variables
if (var != nullptr && var->hasNDArray()) { if (var != nullptr && var->hasNDArray()) {
auto array = var->getNDArray(); auto array = var->getNDArray();
auto cType = array->dataType();
if (_descriptor->isSameMode()) { inputTypes[inT++] = array->dataType();
if (!_descriptor->checkInputMatch(cnt, array->dataType())) {
if (index >= block.width()) { auto ctype = DataTypeUtils::asString(array->dataType());
auto iv = block.variable(0); nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n",
_descriptor->getOpName()->data(), cnt, ctype.c_str());
if (iv->getNDArray()->dataType() != cType) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", _descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
} else {
// for same mode, output type must be the same as input type
auto iv = block.variable(index);
if (iv->getNDArray()->dataType() != cType) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n", _descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
}
} else if (_descriptor->isInherit(index)) {
// in inherit mode, output type must be the same as one of input types
if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n", _descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
} else if (!_descriptor->checkOutputMatch(index, cType)) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%i];\n", _descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS; return ND4J_STATUS_BAD_ARGUMENTS;
} }
} }
} else
break; cnt++;
}
}
if (block.isFastPath()) {
int index = 0;
for (auto array: block.fastpath_out()) {
auto cType = array->dataType();
if (_descriptor->isSameMode()) {
if (index >= block.width()) {
auto ia = block.fastpath_in()[0];
if (ia->dataType() != cType) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
_descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
} else {
// for same mode, output type must be the same as input type
auto ia = block.fastpath_in()[index];
if (ia->dataType() != cType) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
_descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
}
} else if (_descriptor->isInherit(index)) {
// in inherit mode, output type must be the same as one of input types
if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n",
_descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
} else if (!_descriptor->checkOutputMatch(index, cType)) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n",
_descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
index++;
}
} else {
// checking optionally available outputs
auto varSpace = block.getVariableSpace();
for (int index = 0; index < DataTypeUtils::max<int>(); index++) {
if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) {
auto var = block.variable(block.nodeId(), index);
// only validating non-null variables
if (var != nullptr && var->hasNDArray()) {
auto array = var->getNDArray();
auto cType = array->dataType();
if (_descriptor->isSameMode()) {
if (index >= block.width()) {
auto iv = block.variable(0);
if (iv->getNDArray()->dataType() != cType) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
_descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
} else {
// for same mode, output type must be the same as input type
auto iv = block.variable(index);
if (iv->getNDArray()->dataType() != cType) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
_descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
}
} else if (_descriptor->isInherit(index)) {
// in inherit mode, output type must be the same as one of input types
if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n",
_descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
} else if (!_descriptor->checkOutputMatch(index, cType)) {
auto t = DataTypeUtils::asString(cType);
nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n",
_descriptor->getOpName()->data(), index, t.c_str());
return ND4J_STATUS_BAD_ARGUMENTS;
}
}
} else
break;
}
} }

View File

@ -400,6 +400,32 @@ TEST_F(JavaInteropTests, Test_Synonyms_3) {
ASSERT_EQ(nameRef, name); ASSERT_EQ(nameRef, name);
} }
TEST_F(JavaInteropTests, Test_FastPath_Validation_1) {
auto x = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
auto z = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
nd4j::ops::softmax op;
auto status = op.execute(&ctx);
ASSERT_NE(Status::OK(), status);
}
TEST_F(JavaInteropTests, Test_FastPath_Validation_2) {
auto x = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
auto z = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
nd4j::ops::softmax op;
auto status = op.execute(&ctx);
ASSERT_NE(Status::OK(), status);
}
/* /*
TEST_F(JavaInteropTests, test_avgpooling_edge_1) { TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
int inOutH = 35; int inOutH = 35;

View File

@ -992,81 +992,6 @@ TEST_F(NativeOpsTests, ScalarTadTest_2) {
ASSERT_TRUE(exp.e<bool>(5) == z.e<bool>(5) && exp.e<bool>(15)); ASSERT_TRUE(exp.e<bool>(5) == z.e<bool>(5) && exp.e<bool>(15));
} }
TEST_F(NativeOpsTests, FlattenTest_1) {
auto x = NDArrayFactory::create<float>('c', {5, 5});
auto y = NDArrayFactory::create<float>('c', {5, 5});
auto exp = NDArrayFactory::create<float>('c', {2, 5,5});
auto z = NDArrayFactory::create<float>('c', {2, 5,5});
Nd4jPointer extra[6];
#ifdef __CUDABLAS__
extra[1] = x.getContext()->getCudaStream();
extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr;
x.syncToHost();
y.syncToHost();
printf("Unsupported for CUDA platform yet.\n");
return;
#endif
x.linspace(1.0,2);
y.linspace(2,2);
//y.assign(2.);
x.syncToDevice();
z.syncToDevice();
auto dimension = NDArrayFactory::create<int>({0, 1});
auto dimensions = reinterpret_cast<int*>(dimension.buffer());
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
exp(1, {0}).linspace(1,2);
::flatten(extra,
25, 'c', z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
x.buffer(), x.shapeInfo(),
x.specialBuffer(), x.specialShapeInfo());
// exp.printIndexedBuffer("Exp");
// z.printIndexedBuffer("Flatten");
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(NativeOpsTests, ConcatTest_1) {
auto x = NDArrayFactory::create<float>('c', {5, 5});
auto y = NDArrayFactory::create<float>('c', {5, 5});
auto exp = NDArrayFactory::create<float>('c', {10,5});
auto z = NDArrayFactory::create<float>('c', {10,5});
Nd4jPointer extra[6];
#ifdef __CUDABLAS__
extra[1] = x.getContext()->getCudaStream();
extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr;
x.syncToHost();
y.syncToHost();
printf("Unsupported for CUDA platform yet.\n");
return;
#endif
x.linspace(1.0);
y.linspace(26);
//y.assign(2.);
x.syncToDevice();
z.syncToDevice();
int d = 0;
auto dimension = NDArrayFactory::create<int>('c', {1}, {d});
auto dimensions = reinterpret_cast<int*>(dimension.buffer());
//auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
exp.linspace(1);
Nd4jPointer datas[] = {x.buffer(), y.buffer()};
Nd4jPointer shapes[] = {x.shapeInfo(), y.shapeInfo()};
::concat(extra,
0, 2, datas, shapes, nullptr, nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
nullptr, nullptr);
// exp.printIndexedBuffer("Exp");
// z.printIndexedBuffer("Concat");
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(NativeOpsTests, ConcatTest_2) { TEST_F(NativeOpsTests, ConcatTest_2) {
auto x = NDArrayFactory::create<float>('c', {5, 5}); auto x = NDArrayFactory::create<float>('c', {5, 5});
auto y = NDArrayFactory::create<float>('c', {5, 5}); auto y = NDArrayFactory::create<float>('c', {5, 5});

View File

@ -557,43 +557,6 @@ public interface NativeOps {
@Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets, @Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
@Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ); @Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ);
/**
* @param extraPointers
* @param offset
* @param order
* @param results
* @param resultShapeInfo
* @param input
* @param inputShapeInfo
*/
void flatten(PointerPointer extraPointers,
int offset,
char order,
Pointer results, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresults, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer input, @Cast("Nd4jLong *") LongPointer inputShapeInfo,
Pointer dinput, @Cast("Nd4jLong *") LongPointer dinputShapeInfo);
/**
* @param extraPointers
* @param dimension
* @param numArrays
* @param data
* @param inputShapeInfo
* @param results
* @param resultShapeInfo
* @param tadPointers
* @param tadOffsets
*/
void concat(PointerPointer extraPointers,
int dimension,
int numArrays,
PointerPointer data, PointerPointer inputShapeInfo,
PointerPointer ddata, PointerPointer dinputShapeInfo,
Pointer results, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresults, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
PointerPointer tadPointers,
PointerPointer tadOffsets);
void specialConcat(PointerPointer extraPointers, void specialConcat(PointerPointer extraPointers,
int dimension, int dimension,
@ -1185,4 +1148,7 @@ public interface NativeOps {
Pointer lcCopyStream(OpaqueLaunchContext lc); Pointer lcCopyStream(OpaqueLaunchContext lc);
Pointer lcBlasHandle(OpaqueLaunchContext lc); Pointer lcBlasHandle(OpaqueLaunchContext lc);
Pointer lcSolverHandle(OpaqueLaunchContext lc); Pointer lcSolverHandle(OpaqueLaunchContext lc);
int lastErrorCode();
String lastErrorMessage();
} }

View File

@ -22,6 +22,7 @@ import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.pointers.CudaPointer; import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.linalg.exception.ND4JException; import org.nd4j.linalg.exception.ND4JException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
/** /**
@ -67,14 +68,18 @@ public class cudaEvent_t extends CudaPointer {
int res = NativeOpsHolder.getInstance().getDeviceNativeOps().eventSynchronize(this); int res = NativeOpsHolder.getInstance().getDeviceNativeOps().eventSynchronize(this);
if (res == 0) if (res == 0)
throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]"); throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]");
if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
} }
} }
public void register(cudaStream_t stream) { public void register(cudaStream_t stream) {
if (!isDestroyed()) { if (!isDestroyed()) {
int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream); int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream);
if (res == 0)
throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]"); if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
} }
} }
} }

View File

@ -36,8 +36,9 @@ public class cudaStream_t extends CudaPointer {
public int synchronize() { public int synchronize() {
NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
int res = nativeOps.streamSynchronize(this); int res = nativeOps.streamSynchronize(this);
if (res == 0)
throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() +"]"); if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
return res; return res;
} }

View File

@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx; import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.buffer.Utf8Buffer; import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.ops.custom.Flatten;
import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
@ -104,6 +105,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
functions.put(11, Loader.addressof("cusolverDnSgesvd")); functions.put(11, Loader.addressof("cusolverDnSgesvd"));
functions.put(12, Loader.addressof("cusolverDnDgesvd")); functions.put(12, Loader.addressof("cusolverDnDgesvd"));
nativeOps.initializeFunctions(functions); nativeOps.initializeFunctions(functions);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} }
@Override @Override
@ -335,75 +339,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
if (Nd4j.getExecutioner() instanceof GridExecutioner) if (Nd4j.getExecutioner() instanceof GridExecutioner)
((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
int length = 0; return Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[0])))[0];
DataType t = null;
for (INDArray m : matrices) {
length += m.length();
if (t == null)
t = m.dataType();
Preconditions.checkArgument(t == m.dataType(), "Arrays must have same data type");
}
INDArray ret = Nd4j.create(t, new long[] {length}, order);
int linearIndex = 0;
AtomicAllocator allocator = AtomicAllocator.getInstance();
for (INDArray m : matrices) {
if (m.isEmpty())
continue;
CudaContext context = allocator.getFlowController().prepareAction(ret, m);
if (m.ordering() == order && ret.elementWiseStride() == m.elementWiseStride()
&& ret.elementWiseStride() == 1) {
// do memcpy in proper direction and forget about that
// FIXME: get rid of this
((BaseCudaDataBuffer) m.data()).lazyAllocateHostPointer();
allocator.memcpyAsync(ret.data(), new CudaPointer(allocator.getHostPointer(m).address()),
AllocationUtils.getRequiredMemory(AllocationUtils.buildAllocationShape(m)),
linearIndex * (m.data().dataType() == DataType.DOUBLE ? 8
: m.data().dataType() == DataType.FLOAT ? 4 : 2));
linearIndex += m.length();
} else {
Pointer hostYShapeInfo = AddressRetriever.retrieveHostPointer(m.shapeInfoDataBuffer());
PointerPointer extras = new PointerPointer(
AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()), context.getOldStream(),
allocator.getDeviceIdPointer(), null,
context.getBufferReduction(), context.getBufferScalar(), null,
hostYShapeInfo, AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()));
nativeOps.flatten(extras, linearIndex, order,
null,
(LongPointer) allocator.getHostPointer(ret.shapeInfoDataBuffer()),
allocator.getPointer(ret, context),
(LongPointer) allocator.getPointer(ret.shapeInfoDataBuffer(), context),
null,
(LongPointer) allocator.getHostPointer(m.shapeInfoDataBuffer()),
allocator.getPointer(m, context),
(LongPointer) allocator.getPointer(m.shapeInfoDataBuffer(), context));
//Works for all cases...
/* NdIndexIterator iter = new NdIndexIterator(order, m.shape());
while (iter.hasNext()) {
ret.putScalar(linearIndex++, m.getDouble(iter.next()));
}*/
linearIndex += m.length();
}
if (ret != null)
allocator.registerAction(context, ret, m);
}
return ret;
} }
@Override @Override
@ -412,131 +348,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
return Nd4j.exec(new Concat(dimension, toConcat))[0]; return Nd4j.exec(new Concat(dimension, toConcat))[0];
// legacy implementation
/*
boolean allScalars = true;
var outputShape = ArrayUtil.copy(toConcat[0].shape());
if (toConcat.length == 1)
return toConcat[0];
int sumAlongDim = 0;
for (int i = 0; i < toConcat.length; i++) {
if (toConcat[i].isCompressed())
Nd4j.getCompressor().decompressi(toConcat[i]);
allScalars &= toConcat[i].rank() == 0;
sumAlongDim += toConcat[i].size(dimension);
}
if (allScalars) {
outputShape = new long[]{sumAlongDim};
} else {
outputShape[dimension] = sumAlongDim;
}
INDArray ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order());
AtomicAllocator allocator = AtomicAllocator.getInstance();
CudaContext context = allocator.getFlowController().prepareAction(ret, toConcat);
val shapeInfoPointers = new long[toConcat.length];
val dataPointers = new long[toConcat.length];
val tadPointers = new long[toConcat.length];
val offsetsPointers = new long[toConcat.length];
val hostShapeInfoPointers = new long[toConcat.length];
TADManager tadManager = Nd4j.getExecutioner().getTADManager();
for (int i = 0; i < toConcat.length; i++) {
shapeInfoPointers[i] = AddressRetriever.retrieveDeviceAddress(toConcat[i].shapeInfoDataBuffer(), context);
dataPointers[i] = AtomicAllocator.getInstance().getPointer(toConcat[i], context).address();
hostShapeInfoPointers[i] = AtomicAllocator.getInstance().getHostPointer(toConcat[i].shapeInfoDataBuffer()).address();
sumAlongDim += toConcat[i].size(dimension);
for (int j = 0; j < toConcat[i].rank(); j++)
if (j != dimension && toConcat[i].size(j) != outputShape[j]) {
throw new IllegalArgumentException(
"Illegal concatenation at array " + i + " and shape element " + j);
}
if (!allScalars) {
val tadBuffers = tadManager.getTADOnlyShapeInfo(toConcat[i], new int[]{dimension});
long devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context).address();
val offsets = tadBuffers.getSecond();
long devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context).address();
tadPointers[i] = devTadShapeInfo;
offsetsPointers[i] = devTadOffsets;
}
}
// getting tadOnlyShape for result
val zBuffers = tadManager.getTADOnlyShapeInfo(ret, new int[] {dimension});
val hostPointers = new LongPointer(hostShapeInfoPointers);
val hosthost = new PointerPointerWrapper(hostPointers);
//System.out.println("shapePointers: " + Arrays.toString(shapeInfoPointers));
val dZ = AtomicAllocator.getInstance().getPointer(ret, context);
val dZShapeInfo = AddressRetriever.retrieveDevicePointer(ret.shapeInfoDataBuffer(), context);
//val tempData = new CudaDoubleDataBuffer(toConcat.length);
//val tempShapes = new CudaDoubleDataBuffer(toConcat.length);
//val tempTAD = new CudaDoubleDataBuffer(toConcat.length);
//val tempOffsets = new CudaDoubleDataBuffer(toConcat.length);
//AtomicAllocator.getInstance().memcpyBlocking(tempData, new LongPointer(dataPointers), dataPointers.length * 8,0);
//AtomicAllocator.getInstance().memcpyBlocking(tempShapes, new LongPointer(shapeInfoPointers), shapeInfoPointers.length * 8, 0);
//AtomicAllocator.getInstance().memcpyBlocking(tempTAD, new LongPointer(tadPointers), tadPointers.length * 8, 0);
//AtomicAllocator.getInstance().memcpyBlocking(tempOffsets, new LongPointer(offsetsPointers), offsetsPointers.length * 8, 0);
val dataPointer = new PointerPointerWrapper(new LongPointer(dataPointers)); //AtomicAllocator.getInstance().getPointer(tempData, context);
val shapesPointer = new PointerPointerWrapper(new LongPointer(shapeInfoPointers));//AtomicAllocator.getInstance().getPointer(tempShapes, context);
//val tadPointer = AtomicAllocator.getInstance().getPointer(tempTAD, context);
//val offsetPointer = AtomicAllocator.getInstance().getPointer(tempOffsets, context);
// System.out.println("ShapesPointer after conversion: " + shapesPointer);
val extras = new PointerPointer(AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()),
context.getOldStream(), allocator.getDeviceIdPointer(), null,
context.getBufferReduction(), context.getBufferScalar(), null,
AddressRetriever.retrieveHostPointer(toConcat[0].shapeInfoDataBuffer()),
AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()),
new LongPointer(hostShapeInfoPointers),
AtomicAllocator.getInstance().getPointer(zBuffers.getFirst(), context), // getting zTADShape
AtomicAllocator.getInstance().getPointer(zBuffers.getSecond(), context) // getting zOffset
);
nativeOps.concat(extras,
dimension,
toConcat.length,
null,
hosthost,
dataPointer,
shapesPointer,
null,
(LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
dZ,
(LongPointer) dZShapeInfo,
null,
null);
allocator.registerAction(context, ret, toConcat);
return ret;
//return super.concat(dimension, toConcat);
*/
} }
@ -590,6 +401,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
(LongPointer) ret.shapeInfoDataBuffer().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
null, null); null, null);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AllocationPoint point = allocator.getAllocationPoint(ret); AllocationPoint point = allocator.getAllocationPoint(ret);
@ -598,6 +411,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), ret.lengthLong() * Nd4j.sizeOfDataType(ret.data().dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), ret.lengthLong() * Nd4j.sizeOfDataType(ret.data().dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
context.getSpecialStream().synchronize(); context.getSpecialStream().synchronize();
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
point.tickHostRead(); point.tickHostRead();
@ -729,6 +545,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
(LongPointer) zTadShapeInfo, (LongPointer) zTadShapeInfo,
new LongPointerWrapper(zTadOffsets)); new LongPointerWrapper(zTadOffsets));
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
allocator.registerAction(context, ret, source); allocator.registerAction(context, ret, source);
@ -743,7 +561,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
return target.assign(arrays[0]); return target.assign(arrays[0]);
// we do averaging on GPU only if ALL devices have p2p links // we do averaging on GPU only if ALL devices have p2p links
//if (CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() && nativeOps.isP2PAvailable()) {
if (true) { if (true) {
Nd4j.getExecutioner().push(); Nd4j.getExecutioner().push();
@ -781,6 +598,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
nativeOps.accumulate(extras, null, (LongPointer) arrays[0].shapeInfoDataBuffer().addressPointer(), x, null, null, (LongPointer) allocator.getHostPointer(target.shapeInfoDataBuffer()) , z, (LongPointer) allocator.getPointer(target.shapeInfoDataBuffer()), arrays.length, len); nativeOps.accumulate(extras, null, (LongPointer) arrays[0].shapeInfoDataBuffer().addressPointer(), x, null, null, (LongPointer) allocator.getHostPointer(target.shapeInfoDataBuffer()) , z, (LongPointer) allocator.getPointer(target.shapeInfoDataBuffer()), arrays.length, len);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
allocator.getFlowController().registerAction(context, target, arrays); allocator.getFlowController().registerAction(context, target, arrays);
return target; return target;
@ -824,6 +644,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
arrays.length, arrays.length,
len); len);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite(); AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite();
@ -895,6 +717,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
arrays.length, arrays.length,
len, true); len, true);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
allocator.getFlowController().registerAction(context, target, arrays); allocator.getFlowController().registerAction(context, target, arrays);
return target; return target;
@ -940,6 +765,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
arrays.length, arrays.length,
len, true); len, true);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
if (target != null) if (target != null)
AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite(); AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite();
@ -1115,6 +943,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
(IntPointer) shuffleMap, new PointerPointer(allocator.getPointer(tempTAD, context)), (IntPointer) shuffleMap, new PointerPointer(allocator.getPointer(tempTAD, context)),
new PointerPointer(allocator.getPointer(tempOffsets, context))); new PointerPointer(allocator.getPointer(tempOffsets, context)));
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
for (int f = 0; f < arrays.size(); f++) { for (int f = 0; f < arrays.size(); f++) {
allocator.getFlowController().registerAction(context, arrays.get(f)); allocator.getFlowController().registerAction(context, arrays.get(f));
} }
@ -1260,6 +1091,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
val p = new PointerPointer<>(new Pointer[]{null, stream}); val p = new PointerPointer<>(new Pointer[]{null, stream});
nativeOps.convertTypes(p, typeSrc.ordinal(), source, length, typeDst.ordinal(), target); nativeOps.convertTypes(p, typeSrc.ordinal(), source, length, typeDst.ordinal(), target);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} }
@Override @Override
@ -1277,7 +1111,13 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
srcPtr = nativeOps.mallocDevice(ssize, 0, 0); srcPtr = nativeOps.mallocDevice(ssize, 0, 0);
dstPtr = nativeOps.mallocDevice(size, 0, 0); dstPtr = nativeOps.mallocDevice(size, 0, 0);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
nativeOps.memcpyAsync(srcPtr, source, ssize, CudaConstants.cudaMemcpyHostToDevice, stream); nativeOps.memcpyAsync(srcPtr, source, ssize, CudaConstants.cudaMemcpyHostToDevice, stream);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} else { } else {
// decompressing // decompressing
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
@ -1288,9 +1128,15 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
stream.synchronize(); stream.synchronize();
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
if (buffer instanceof CompressedDataBuffer) { if (buffer instanceof CompressedDataBuffer) {
nativeOps.freeDevice(srcPtr, 0); nativeOps.freeDevice(srcPtr, 0);
nativeOps.freeDevice(dstPtr, 0); nativeOps.freeDevice(dstPtr, 0);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} }
} }
@ -1309,13 +1155,15 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
val size = ((CompressedDataBuffer) source).getCompressionDescriptor().getCompressedLength(); val size = ((CompressedDataBuffer) source).getCompressionDescriptor().getCompressedLength();
srcPtr = ws.alloc(size, MemoryKind.DEVICE, DataType.HALF, false); srcPtr = ws.alloc(size, MemoryKind.DEVICE, DataType.HALF, false);
nativeOps.memcpyAsync(srcPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream); nativeOps.memcpyAsync(srcPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} }
// if true - we're compressing into host memory // if true - we're compressing into host memory
if (target instanceof CompressedDataBuffer) { if (target instanceof CompressedDataBuffer) {
val size = ((CompressedDataBuffer) target).getCompressionDescriptor().getCompressedLength(); val size = ((CompressedDataBuffer) target).getCompressionDescriptor().getCompressedLength();
dstPtr = ws.alloc(size, MemoryKind.DEVICE, DataType.HALF, false); dstPtr = ws.alloc(size, MemoryKind.DEVICE, DataType.HALF, false);
//nativeOps.memcpyAsync(dstPtr, target.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream);
} }
} else { } else {
// if true - we're decompressing from host memory // if true - we're decompressing from host memory
@ -1325,6 +1173,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
srcPtr = nativeOps.mallocDevice(size, 0, 0); srcPtr = nativeOps.mallocDevice(size, 0, 0);
nativeOps.memcpyAsync(srcPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream); nativeOps.memcpyAsync(srcPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream);
stream.synchronize(); stream.synchronize();
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} else } else
srcPtr = AtomicAllocator.getInstance().getPointer(source); srcPtr = AtomicAllocator.getInstance().getPointer(source);
@ -1333,8 +1184,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
log.info("Replacing target ptr"); log.info("Replacing target ptr");
val size = ((CompressedDataBuffer) target).getCompressionDescriptor().getCompressedLength(); val size = ((CompressedDataBuffer) target).getCompressionDescriptor().getCompressedLength();
dstPtr = nativeOps.mallocDevice(size, 0, 0); dstPtr = nativeOps.mallocDevice(size, 0, 0);
//nativeOps.memcpyAsync(dstPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream);
//stream.synchronize(); if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} else } else
dstPtr = AtomicAllocator.getInstance().getPointer(target); dstPtr = AtomicAllocator.getInstance().getPointer(target);
} }
@ -1342,6 +1194,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
convertDataEx(typeSrc, srcPtr, typeDst, dstPtr, target.length()); convertDataEx(typeSrc, srcPtr, typeDst, dstPtr, target.length());
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
Nd4j.getExecutioner().commit(); Nd4j.getExecutioner().commit();
@ -1364,6 +1219,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
Nd4j.getExecutioner().commit(); Nd4j.getExecutioner().commit();
} }
@ -1462,6 +1320,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context)) new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context))
); );
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getFlowController().registerActionAllWrite(context, result); AtomicAllocator.getInstance().getFlowController().registerActionAllWrite(context, result);
AtomicAllocator.getInstance().getFlowController().registerAction(context,null, result); AtomicAllocator.getInstance().getFlowController().registerAction(context,null, result);
@ -1517,6 +1378,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
descending descending
); );
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getFlowController().registerAction(context, x); AtomicAllocator.getInstance().getFlowController().registerAction(context, x);
@ -1565,6 +1428,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
descending descending
); );
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getFlowController().registerAction(context, x); AtomicAllocator.getInstance().getFlowController().registerAction(context, x);

View File

@ -207,6 +207,10 @@ public class CudaExecutioner extends DefaultOpExecutioner {
throw new UnsupportedOperationException("Unknown op type: " + op.getOpType()); throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
@ -461,6 +465,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
} }
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
return op.z(); return op.z();
@ -619,7 +626,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
AtomicAllocator.getInstance().getPointer(op.dimensions(), context), AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
null); null);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
@ -777,6 +785,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
throw new UnsupportedOperationException("Unknown opType: " + op.getOpType()); throw new UnsupportedOperationException("Unknown opType: " + op.getOpType());
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
@ -868,6 +879,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y()); AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y());
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
return null; return null;
@ -1105,6 +1119,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
@ -1194,6 +1210,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y()); AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
@ -1268,6 +1287,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
throw new UnsupportedOperationException("Unknown op type: " + op.getOpType()); throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.scalar()); AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.scalar());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
@ -1423,6 +1445,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
} }
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y()); AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
@ -1582,6 +1606,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(),
AtomicAllocator.getInstance().getPointer(surfaceBuffer, context), FlatBuffersMapper.getDataTypeAsByte(dataType)); AtomicAllocator.getInstance().getPointer(surfaceBuffer, context), FlatBuffersMapper.getDataTypeAsByte(dataType));
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
surfacePoint.tickHostWrite(); surfacePoint.tickHostWrite();
} }
@ -1676,6 +1703,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
numIndexArguments, iPtr, numIntArrays, numIndexArguments, iPtr, numIntArrays,
AtomicAllocator.getInstance().getPointer(realsBuffer.data(), context), AtomicAllocator.getInstance().getPointer(realsBuffer.data(), context),
numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType)); numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType));
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} }
/** /**
@ -1739,6 +1769,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context)); AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context));
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y()); AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
@ -1969,6 +2002,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
nativeOps.decodeThreshold(extras, AtomicAllocator.getInstance().getPointer(buffer), compressedLength, AtomicAllocator.getInstance().getPointer(result), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer())); nativeOps.decodeThreshold(extras, AtomicAllocator.getInstance().getPointer(buffer), compressedLength, AtomicAllocator.getInstance().getPointer(result), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer()));
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getAllocationPoint(result).tickDeviceWrite(); AtomicAllocator.getInstance().getAllocationPoint(result).tickDeviceWrite();
return target; return target;
@ -2013,7 +2049,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
(IntPointer) AtomicAllocator.getInstance().getPointer(buffer, context), (IntPointer) AtomicAllocator.getInstance().getPointer(buffer, context),
(float) threshold); (float) threshold);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getFlowController().registerAction(context, indArray); AtomicAllocator.getInstance().getFlowController().registerAction(context, indArray);
@ -2039,6 +2076,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
nativeOps.decodeBitmap(extras, AtomicAllocator.getInstance().getPointer(encoded.data(), context), target.lengthLong(), AtomicAllocator.getInstance().getPointer(target, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer())); nativeOps.decodeBitmap(extras, AtomicAllocator.getInstance().getPointer(encoded.data(), context), target.lengthLong(), AtomicAllocator.getInstance().getPointer(target, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer()));
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getFlowController().registerAction(context, target); AtomicAllocator.getInstance().getFlowController().registerAction(context, target);
@ -2151,6 +2190,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
if (ptrptr == null) if (ptrptr == null)
throw new RuntimeException(); throw new RuntimeException();
@ -2221,109 +2263,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException("Op [" + name + "] execution failed", e); throw new RuntimeException("Op [" + name + "] execution failed", e);
} }
/*
long st = profilingConfigurableHookIn(op);
CudaContext context =(CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
//AtomicAllocator.getInstance().getFlowController().prepareActionAllWrite(op.outputArguments());
if (extraz.get() == null)
extraz.set(new PointerPointer(32));
PointerPointer extras = extraz.get().put(
new CudaPointer(1),
context.getOldStream(),
context.getBufferScalar(),
context.getBufferReduction());
val outputArgs = op.outputArguments();
val inputArgs = op.inputArguments();
if (outputArgs.length == 0 && !op.isInplaceCall())
throw new ND4JIllegalStateException("You can't execute non-inplace CustomOp without outputs being specified");
val lc = op.opName().toLowerCase();
val hash = op.opHash();
val inputShapes = new PointerPointer<>(inputArgs.length * 2);
val inputBuffers = new PointerPointer<>(inputArgs.length * 2);
int cnt= 0;
for (val in: inputArgs) {
val hp = AtomicAllocator.getInstance().getHostPointer(in.shapeInfoDataBuffer());
inputBuffers.put(cnt, AtomicAllocator.getInstance().getHostPointer(in));
inputShapes.put(cnt, hp);
val dp = AtomicAllocator.getInstance().getPointer(in.shapeInfoDataBuffer(), context);
inputBuffers.put(cnt + inputArgs.length, AtomicAllocator.getInstance().getPointer(in, context));
inputShapes.put(cnt+ inputArgs.length, dp);
if (op.isInplaceCall()) {
val ap = AtomicAllocator.getInstance().getAllocationPoint(in);
if (ap != null)
ap.tickHostWrite();
}
cnt++;
}
val outputShapes = new PointerPointer<>(outputArgs.length * 2);
val outputBuffers = new PointerPointer<>(outputArgs.length * 2);
cnt= 0;
for (val out: outputArgs) {
outputBuffers.put(cnt, AtomicAllocator.getInstance().getHostPointer(out));
outputShapes.put(cnt, AtomicAllocator.getInstance().getHostPointer(out.shapeInfoDataBuffer()));
outputBuffers.put(cnt + outputArgs.length, AtomicAllocator.getInstance().getPointer(out, context));
outputShapes.put(cnt + outputArgs.length, AtomicAllocator.getInstance().getPointer(out.shapeInfoDataBuffer(), context));
val ap = AtomicAllocator.getInstance().getAllocationPoint(out);
if (ap != null)
ap.tickHostWrite();
cnt++;
}
val iArgs = op.iArgs().length > 0 ? new LongPointer(op.iArgs().length) : null;
cnt = 0;
for (val i: op.iArgs())
iArgs.put(cnt++, i);
val tArgs = op.tArgs().length > 0 ? new DoublePointer(op.tArgs().length) : null;
val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.numBArguments()) : null;
cnt = 0;
for (val t: op.tArgs())
tArgs.put(cnt++, t);
cnt = 0;
for (val b: op.bArgs())
bArgs.put(cnt++, b);
try {
val status = OpStatus.byNumber(nativeOps.execCustomOp(extras, hash, inputBuffers, inputShapes, inputArgs.length, outputBuffers, outputShapes, outputArgs.length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), op.isInplaceCall()));
if (status != OpStatus.ND4J_STATUS_OK)
throw new ND4JIllegalStateException("Op execution failed: " + status);
} catch (Exception e) {
throw new RuntimeException("Op [" + op.opName() + "] execution failed");
}
//AtomicAllocator.getInstance().getFlowController().prepareActionAllWrite(op.outputArguments());
profilingConfigurableHookOut(op, st);
return op.outputArguments();
*/
} }
@Override @Override
@ -2341,6 +2280,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public void registerGraph(long id, Pointer graph) { public void registerGraph(long id, Pointer graph) {
nativeOps.registerGraph(null, id, graph); nativeOps.registerGraph(null, id, graph);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} }
@Override @Override
@ -2368,6 +2310,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(result)); OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(result));
if (status != OpStatus.ND4J_STATUS_OK) if (status != OpStatus.ND4J_STATUS_OK)
@ -2398,6 +2343,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
newMap.put(nodeName, array); newMap.put(nodeName, array);
} }
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
nativeOps.deleteVariablesSet(result); nativeOps.deleteVariablesSet(result);
return newMap; return newMap;
@ -2406,6 +2354,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public void forgetGraph(long id) { public void forgetGraph(long id) {
nativeOps.unregisterGraph(null, id); nativeOps.unregisterGraph(null, id);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} }
/** /**
@ -2474,6 +2425,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(tadY.getFirst()), null, AtomicAllocator.getInstance().getPointer(updates, context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getFirst()), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getSecond()), null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(tadY.getFirst()), null, AtomicAllocator.getInstance().getPointer(updates, context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getFirst()), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getSecond()),
null, (IntPointer) AtomicAllocator.getInstance().getPointer(indices, context)); null, (IntPointer) AtomicAllocator.getInstance().getPointer(indices, context));
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
AtomicAllocator.getInstance().getFlowController().registerAction(context, array, indices, updates); AtomicAllocator.getInstance().getFlowController().registerAction(context, array, indices, updates);
} }
@ -2490,9 +2444,14 @@ public class CudaExecutioner extends DefaultOpExecutioner {
((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation()); ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation());
val status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer()); val status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
if (status != 0) if (status != 0)
throw new RuntimeException("Op [" + op.opName() + "] execution failed"); throw new RuntimeException("Op [" + op.opName() + "] execution failed");
for (val arr:op.outputArguments()) for (val arr:op.outputArguments())
AtomicAllocator.getInstance().registerAction(ctx, arr); AtomicAllocator.getInstance().registerAction(ctx, arr);
@ -2527,6 +2486,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
nativeOps.inspectArray(extras, AtomicAllocator.getInstance().getHostPointer(array), (LongPointer) AtomicAllocator.getInstance().getHostPointer(array.shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(array, ctx), (LongPointer) AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()), debugInfo); nativeOps.inspectArray(extras, AtomicAllocator.getInstance().getHostPointer(array), (LongPointer) AtomicAllocator.getInstance().getHostPointer(array.shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(array, ctx), (LongPointer) AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()), debugInfo);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
return INDArrayStatistics.builder() return INDArrayStatistics.builder()
.minValue(debugInfo._minValue()) .minValue(debugInfo._minValue())
.maxValue(debugInfo._maxValue()) .maxValue(debugInfo._maxValue())
@ -2545,6 +2507,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
val result = new CudaLongDataBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), Shape.shapeInfoLength(shape.length)); val result = new CudaLongDataBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), Shape.shapeInfoLength(shape.length));
nativeOps.deleteShapeBuffer(dbf); nativeOps.deleteShapeBuffer(dbf);
@ -2556,6 +2521,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
val tadShape = new CudaLongDataBuffer(nativeOps.getPrimaryShapeInfo(pack), nativeOps.getSpecialShapeInfo(pack), nativeOps.getShapeInfoLength(pack)); val tadShape = new CudaLongDataBuffer(nativeOps.getPrimaryShapeInfo(pack), nativeOps.getSpecialShapeInfo(pack), nativeOps.getShapeInfoLength(pack));
val tadOffsets = new CudaLongDataBuffer(nativeOps.getPrimaryOffsets(pack), nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack)); val tadOffsets = new CudaLongDataBuffer(nativeOps.getPrimaryOffsets(pack), nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack));
@ -2568,6 +2536,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) { public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length); OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType); val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
buffer.setConstant(true); buffer.setConstant(true);
@ -2578,6 +2549,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) { public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length); OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType); val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
buffer.setConstant(true); buffer.setConstant(true);

View File

@ -449,6 +449,60 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
// #endif //DEV_TESTS_TADPACK_H // #endif //DEV_TESTS_TADPACK_H
// Parsed from execution/ErrorReference.h
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
// #ifndef DEV_TESTS_ERRORREFERENCE_H
// #define DEV_TESTS_ERRORREFERENCE_H
// #include <string>
// #include <dll.h>
@Namespace("sd") @NoOffset public static class ErrorReference extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ErrorReference(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public ErrorReference(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public ErrorReference position(long position) {
return (ErrorReference)super.position(position);
}
public ErrorReference() { super((Pointer)null); allocate(); }
private native void allocate();
public native int errorCode();
public native @Cast("char*") String errorMessage();
public native void setErrorCode(int errorCode);
public native void setErrorMessage(@StdString BytePointer message);
public native void setErrorMessage(@StdString String message);
}
// #endif //DEV_TESTS_ERRORREFERENCE_H
// Parsed from memory/MemoryType.h // Parsed from memory/MemoryType.h
// //
@ -688,6 +742,18 @@ bool verbose = false;
// #include <graph/ResultWrapper.h> // #include <graph/ResultWrapper.h>
// #include <DebugInfo.h> // #include <DebugInfo.h>
/**
* This function returns last error code stored,
* @return non-zero if something bad happened
*/
public native int lastErrorCode();
/**
* This function returns last error message, if last error code > 0
* @return
*/
public native @Cast("char*") String lastErrorMessage();
/** /**
* *
* @param p * @param p
@ -1710,72 +1776,6 @@ public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraP
@Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets,
@Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ);
/**
* Append an input array
* to the end of a flat array
* in a particular order
* @param offset the offset of the array to start at
* @param order the order
* @param result the result array
* @param resultShapeInfo the shape info for te array
* @param input the input for the array
* @param inputShapeInfo the shape information for that array
*/
public native void flatten(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int offset,
char order,
Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo,
Pointer input, @Cast("Nd4jLong*") LongPointer inputShapeInfo,
Pointer dinput, @Cast("Nd4jLong*") LongPointer dinputShapeInfo);
public native void flatten(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int offset,
char order,
Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo,
Pointer input, @Cast("Nd4jLong*") LongBuffer inputShapeInfo,
Pointer dinput, @Cast("Nd4jLong*") LongBuffer dinputShapeInfo);
public native void flatten(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int offset,
char order,
Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo,
Pointer input, @Cast("Nd4jLong*") long[] inputShapeInfo,
Pointer dinput, @Cast("Nd4jLong*") long[] dinputShapeInfo);
public native void concat(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int dimension,
int numArrays,
@Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
@Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo,
@Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void concat(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int dimension,
int numArrays,
@Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
@Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo,
@Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void concat(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int dimension,
int numArrays,
@Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
@Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo,
@Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void specialConcat( public native void specialConcat(
@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer*") PointerPointer extraPointers,
int dimension, int dimension,
@ -9950,6 +9950,7 @@ public static final int PREALLOC_SIZE = 33554432;
// #include <dll.h> // #include <dll.h>
// #include <pointercast.h> // #include <pointercast.h>
// #include <execution/ErrorReference.h>
@Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer { @Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
@ -9985,6 +9986,8 @@ public static final int PREALLOC_SIZE = 33554432;
public native void setScalarBuffer(Pointer pointer); public native void setScalarBuffer(Pointer pointer);
public native void setAllocationBuffer(Pointer pointer); public native void setAllocationBuffer(Pointer pointer);
public native ErrorReference errorReference();
public native void triggerOwnership(@Cast("bool") boolean isOwner); public native void triggerOwnership(@Cast("bool") boolean isOwner);
public native int deviceId(); public native int deviceId();
@ -10038,6 +10041,7 @@ public static final int PREALLOC_SIZE = 33554432;
// #include <vector> // #include <vector>
// #include <mutex> // #include <mutex>
// #include <execution/ContextBuffers.h> // #include <execution/ContextBuffers.h>
// #include <execution/ErrorReference.h>
@Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer { @Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer {
static { Loader.load(); } static { Loader.load(); }
@ -10067,9 +10071,12 @@ public static final int PREALLOC_SIZE = 33554432;
public native int getDeviceID(); public native int getDeviceID();
public native void setDeviceID(int deviceID); public native void setDeviceID(int deviceID);
public native ErrorReference errorReference();
public static native @Cast("bool") boolean isInitialized(); public static native @Cast("bool") boolean isInitialized();
public static native void releaseBuffers(); public static native void releaseBuffers();
public static native LaunchContext defaultContext(); public static native LaunchContext defaultContext();

View File

@ -32,6 +32,7 @@ import org.bytedeco.javacpp.tools.InfoMapper;
"array/ConstantDescriptor.h", "array/ConstantDescriptor.h",
"array/ConstantDataBuffer.h", "array/ConstantDataBuffer.h",
"array/TadPack.h", "array/TadPack.h",
"execution/ErrorReference.h",
"memory/MemoryType.h", "memory/MemoryType.h",
"Environment.h", "Environment.h",
"types/utf8string.h", "types/utf8string.h",

View File

@ -106,6 +106,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
functions.put(8, Loader.addressof("LAPACKE_sgesdd")); functions.put(8, Loader.addressof("LAPACKE_sgesdd"));
functions.put(9, Loader.addressof("LAPACKE_dgesdd")); functions.put(9, Loader.addressof("LAPACKE_dgesdd"));
nativeOps.initializeFunctions(functions); nativeOps.initializeFunctions(functions);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} }
@Override @Override
@ -489,32 +492,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
@Override @Override
public INDArray toFlattened(char order, Collection<INDArray> matrices) { public INDArray toFlattened(char order, Collection<INDArray> matrices) {
Preconditions.checkArgument(matrices.size() > 0, "toFlattened expects > 0 operands"); Preconditions.checkArgument(matrices.size() > 0, "toFlattened expects > 0 operands");
/*
int length = 0;
val list = new ArrayList<INDArray>(matrices);
val t = list.get(0).dataType();
for (INDArray m : matrices) {
length += m.length();
Preconditions.checkArgument(m.dataType() == t, "All operands must have same data type");
}
INDArray ret = Nd4j.create(t, new long[] {length}, order);
int linearIndex = 0;
PointerPointer dummy = new PointerPointer(new Pointer[] {null});
for (INDArray m : matrices) {
Nd4j.getCompressor().autoDecompress(m);
nativeOps.flatten(dummy, linearIndex, order,
ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
null, null,
m.data().addressPointer(),
(LongPointer) m.shapeInfoDataBuffer().addressPointer(),
null, null);
linearIndex += m.length();
}
return ret;
*/
return Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[matrices.size()])))[0]; return Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[matrices.size()])))[0];
} }
@ -555,6 +533,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
new LongPointerWrapper(tadBuffers.getSecond().pointer()) new LongPointerWrapper(tadBuffers.getSecond().pointer())
); );
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
return result; return result;
} }
@ -574,65 +555,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
return toConcat[0]; return toConcat[0];
return Nd4j.exec(new Concat(dimension, toConcat))[0]; return Nd4j.exec(new Concat(dimension, toConcat))[0];
// legacy implementation
/*
// if reusable var wasn't created for this thread, or is smaller then needed - set it to new value
if (extrazA.get() == null || extrazB.get() == null || extrazSize.get() == null || extrazSize.get() < toConcat.length) {
extrazA.set(new PointerPointer(toConcat.length));
extrazB.set(new PointerPointer(toConcat.length));
extrazSize.set(toConcat.length);
}
PointerPointer shapeInfoPointers = extrazA.get();
PointerPointer dataPointers = extrazB.get();
int sumAlongDim = 0;
long[] outputShape = ArrayUtil.copy(toConcat[0].shape());
boolean allScalars = true;
for (int i = 0; i < toConcat.length; i++) {
Preconditions.checkState(toConcat[i].rank() == outputShape.length, "Encountered different array ranks for concat: input[0].shape()=%ndShape, input[%s].shape()=%ndShape",
toConcat[0], i, toConcat[i]);
if (toConcat[i].isCompressed())
Nd4j.getCompressor().decompressi(toConcat[i]);
Preconditions.checkArgument(toConcat[i].dataType() == toConcat[0].dataType(), "All operands must have same data type: input 0 has type %s, input %s has type %s",
toConcat[0].dataType(), i, toConcat[i].dataType());
allScalars &= toConcat[i].rank() == 0;
shapeInfoPointers.put(i, toConcat[i].shapeInfoDataBuffer().addressPointer());
dataPointers.put(i, toConcat[i].data().addressPointer());
sumAlongDim += toConcat[i].size(dimension);
for (int j = 0; j < toConcat[i].rank(); j++) {
if (j != dimension && toConcat[i].size(j) != outputShape[j]) {
throw new IllegalArgumentException(
"Illegal concatenation at array " + i + " and shape element " + j);
}
}
}
if (allScalars) {
outputShape = new long[]{sumAlongDim};
} else {
outputShape[dimension] = sumAlongDim;
}
INDArray ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order());
nativeOps.concat(null, dimension, toConcat.length,
dataPointers, shapeInfoPointers,
null, null,
ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
null, null,
null, null);
return ret;
*/
} }
@ -757,6 +679,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
(LongPointer) zTadShapeInfo, (LongPointer) zTadShapeInfo,
new LongPointerWrapper(zTadOffsets)); new LongPointerWrapper(zTadOffsets));
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
return ret; return ret;
} }
@ -794,6 +718,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
arrays.length, arrays.length,
len); len);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
return target; return target;
} }
@ -846,6 +773,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
len, len,
true); true);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
return target; return target;
} }
@ -983,6 +913,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
arrays.size(), arrays.size(),
ptrMap, tadPointers, offsetPointers); ptrMap, tadPointers, offsetPointers);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
dataPointers.address(); dataPointers.address();
shapePointers.address(); shapePointers.address();
@ -990,84 +922,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
offsetPointers.address(); offsetPointers.address();
} }
/**
* This method converts Half-precision databuffer to current dType buffer.
*
* @param buffer
* @return
*/
/*
@Override
public DataBuffer restoreFromHalfs(DataBuffer buffer) {
if (buffer.dataType() != DataType.COMPRESSED)
throw new IllegalStateException("DataBuffer contains wrong data: " + buffer.dataType());
CompressedDataBuffer comp = (CompressedDataBuffer) buffer;
CompressionDescriptor descriptor = comp.getCompressionDescriptor();
DataBuffer targetBuffer = Nd4j.createBuffer(descriptor.getCompressedLength() / 2);
if (Nd4j.dataType() == DataType.DOUBLE) {
nativeOps.convertHalfsToDoubles(
null,
comp.addressPointer(),
(int) descriptor.getCompressedLength() / 2,
targetBuffer.addressPointer()
);
} else if (Nd4j.dataType() == DataType.FLOAT) {
nativeOps.convertHalfsToFloats(
null,
comp.addressPointer(),
(int) descriptor.getCompressedLength() / 2,
targetBuffer.addressPointer()
);
} else {
throw new UnsupportedOperationException("Target dtype isn't supported: " + Nd4j.dataType());
}
return targetBuffer;
}
*/
/**
* This method converts Single/Double precision databuffer to Half-precision databuffer
*
* @param buffer
* @return
*/
/*@Override
public DataBuffer convertToHalfs(DataBuffer buffer) {
// we allocate pointer
ShortPointer pointer = new ShortPointer(buffer.length());
if (buffer.dataType() == DataType.DOUBLE) {
nativeOps.convertDoublesToHalfs(
null,
buffer.addressPointer(),
(int) buffer.length(),
pointer
);
} else if (buffer.dataType() == DataType.FLOAT) {
nativeOps.convertFloatsToHalfs(
null,
buffer.addressPointer(),
(int) buffer.length(),
pointer
);
} else {
throw new UnsupportedOperationException("Source dtype isn't supported: " + buffer.dataType());
}
CompressionDescriptor descriptor = new CompressionDescriptor(buffer, new Float16());
descriptor.setCompressedLength(buffer.length() * 2);
CompressedDataBuffer result = new CompressedDataBuffer(pointer, descriptor);
return result;
}
*/
/** /**
* This method converts Single/Double precision databuffer to Half-precision databuffer * This method converts Single/Double precision databuffer to Half-precision databuffer
* *
@ -1081,6 +935,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
throw new UnsupportedOperationException("Impossible to compress View. Consider using dup() before. "); throw new UnsupportedOperationException("Impossible to compress View. Consider using dup() before. ");
DataBuffer buffer = convertDataEx(typeSrc, source.data(), typeDst); DataBuffer buffer = convertDataEx(typeSrc, source.data(), typeDst);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
source.setData(buffer); source.setData(buffer);
if (buffer instanceof CompressedDataBuffer) if (buffer instanceof CompressedDataBuffer)
@ -1125,6 +982,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
convertDataEx(typeSrc, source, typeDst, buffer); convertDataEx(typeSrc, source, typeDst, buffer);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
return buffer; return buffer;
} }
@ -1132,6 +992,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target, public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target,
long length) { long length) {
nativeOps.convertTypes(null, typeSrc.ordinal(), source, length, typeDst.ordinal(), target); nativeOps.convertTypes(null, typeSrc.ordinal(), source, length, typeDst.ordinal(), target);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
} }
@Override @Override

View File

@ -234,6 +234,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
null); null);
} }
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
return op.z(); return op.z();
} }
@ -563,6 +566,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
} }
} }
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
return ret; return ret;
} }
@ -644,6 +650,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
} }
public INDArray exec(ScalarOp op) { public INDArray exec(ScalarOp op) {
@ -690,6 +698,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]");
} }
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
return op.z(); return op.z();
@ -886,6 +897,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
} }
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
} }
@ -962,6 +976,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
throw new UnsupportedOperationException("Unknown operation type: [" + op.getOpType() + "]"); throw new UnsupportedOperationException("Unknown operation type: [" + op.getOpType() + "]");
} }
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
return op.z(); return op.z();
} }
@ -1091,6 +1107,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(),
batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), pointer, FlatBuffersMapper.getDataTypeAsByte(dataType)); batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), pointer, FlatBuffersMapper.getDataTypeAsByte(dataType));
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
} }
/** /**
@ -1197,6 +1216,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
numIndexArguments, intArrays, numIntArrays, block.getRealArgumentsPointer(), numIndexArguments, intArrays, numIntArrays, block.getRealArgumentsPointer(),
numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType)); numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType));
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
} }
/** /**
@ -1284,6 +1305,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
op.extraArgsDataBuff(op.z().dataType()).addressPointer()); op.extraArgsDataBuff(op.z().dataType()).addressPointer());
} }
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
profilingConfigurableHookOut(op, st); profilingConfigurableHookOut(op, st);
return op.z(); return op.z();
@ -1370,6 +1394,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
(float) threshold); (float) threshold);
//long t2 = System.currentTimeMillis(); //long t2 = System.currentTimeMillis();
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
if (cntAbs < 2) if (cntAbs < 2)
return null; return null;
@ -1429,6 +1456,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
loop.convertTypes(null, DataTypeEx.THRESHOLD.ordinal(), buffer.addressPointer(), target.length(), typeDst.ordinal(), target.data().addressPointer()); loop.convertTypes(null, DataTypeEx.THRESHOLD.ordinal(), buffer.addressPointer(), target.length(), typeDst.ordinal(), target.data().addressPointer());
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
return target; return target;
} }
@ -1460,6 +1490,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
(IntPointer) buffer.addressPointer(), (IntPointer) buffer.addressPointer(),
(float) threshold); (float) threshold);
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
return affected; return affected;
} }
@ -1473,6 +1506,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
(LongPointer) target.shapeInfoDataBuffer().addressPointer() (LongPointer) target.shapeInfoDataBuffer().addressPointer()
); );
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
return target; return target;
} }
@ -1673,136 +1709,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException("Op [" + name + "] execution failed", e); throw new RuntimeException("Op [" + name + "] execution failed", e);
} }
/*
val name = op.opName().toLowerCase();
val hash = op.opHash();
if (name.equals("noop")) {
return op.outputArguments();
}
val inputShapes = getInputShapes(op.numInputArguments());
val inputBuffers = getInputBuffers(op.numInputArguments());
int cnt= 0;
val inputArgs = op.inputArguments();
for (val in: inputArgs) {
if(in == null)
throw new NullPointerException("Input argument is null for op " + op.getClass().getName());
if (!in.isEmpty())
inputBuffers.put(cnt, in.data().addressPointer());
inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer());
}
val outputArgs = op.outputArguments();
for(int i = 0; i < outputArgs.length; i++) {
if(outputArgs[i] == null)
throw new ND4JIllegalStateException("Op output arguments must not be null! Op " + op.getClass().getName());
}
val outputShapes = getOutputShapes(op.numOutputArguments());
val outputBuffers = getOutputBuffers(op.numOutputArguments());
cnt= 0;
for (val out: outputArgs) {
if(out.isEmpty()){
outputBuffers.put(cnt, null);
} else {
outputBuffers.put(cnt, out.data().addressPointer());
}
outputShapes.put(cnt++, out.shapeInfoDataBuffer().addressPointer());
}
val iArgs = op.numIArguments() > 0 ? getLongPointerFrom(iArgsPointer,op.numIArguments()) : null;
val tArgs = op.numTArguments() > 0 ? getDoublePointerFrom(tArgsPointer,op.numTArguments()) : null;
val bArgs = op.numBArguments() > 0 ? getBooleanPointerFrom(bArgsPointer,op.numBArguments()) : null;
cnt = 0;
val iArgs1 = op.iArgs();
for (val i: iArgs1)
iArgs.put(cnt++, i);
cnt = 0;
val bArgs1 = op.bArgs();
for (val b: bArgs1)
bArgs.put(cnt++, b);
cnt = 0;
val tArgs1 = op.tArgs();
for (val t: tArgs1)
tArgs.put(cnt++, t);
val t = op.numInputArguments();
OpStatus status = OpStatus.ND4J_STATUS_OK;
try {
val code = loop.execCustomOp(
null,
hash,
inputBuffers,
inputShapes,
op.numInputArguments(),
outputBuffers,
outputShapes,
op.numOutputArguments(),
tArgs, op.numTArguments(),
iArgs, op.numIArguments(),
bArgs, op.numBArguments(),
op.isInplaceCall());
status = OpStatus.byNumber(code);
if (status != OpStatus.ND4J_STATUS_OK)
throw new ND4JIllegalStateException("Failed to execute op [" + name + "] with error code [" + status +"]");
}catch(Exception e) {
val sb = new StringBuilder();
sb.append("Inputs: [(");
for( int i=0; i<inputArgs.length; i++ ){
if(i > 0)
sb.append("), (");
sb.append(Shape.shapeToStringShort(inputArgs[i]));
}
sb.append(")]. Outputs: [(");
for( int i=0; i<outputArgs.length; i++){
if(i > 0)
sb.append("), (");
sb.append(Shape.shapeToStringShort(outputArgs[i]));
}
sb.append(")]. tArgs: ");
if(op.numTArguments() > 0){
sb.append(Arrays.toString(op.tArgs()));
} else {
sb.append("-");
}
sb.append(". iArgs: ");
if(op.numIArguments() > 0){
sb.append(Arrays.toString(op.iArgs()));
} else {
sb.append("-");
}
if(op instanceof DifferentialFunction){
String n = ((DifferentialFunction) op).getOwnName();
if(n != null && !n.equals(op.opName())){
sb.append(". Op own name: \"").append(n).append("\"");
}
}
log.error("Failed to execute op " + op.opName() + ". Attempted to execute with " +
String.valueOf(op.numInputArguments()) + " inputs, " +
String.valueOf(op.numOutputArguments()) + " outputs, "+
String.valueOf(op.numTArguments()) + " targs and " +
String.valueOf(op.numIArguments()) + " iargs. " +
sb.toString() +
" - Please see above message (printed out from c++) for a possible cause of error.");
throw e;
}
profilingConfigurableHookOut(op, st);
return op.outputArguments();
*/
} }
protected LongShapeDescriptor getShapeFromPointer(LongPointer ptr) { protected LongShapeDescriptor getShapeFromPointer(LongPointer ptr) {
@ -1870,6 +1776,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
ptrptr = loop.calculateOutputShapes2(null, ptrptr = loop.calculateOutputShapes2(null,
hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs, hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments()); op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments());
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
} catch (Throwable t){ } catch (Throwable t){
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
sb.append("Inputs: [("); sb.append("Inputs: [(");
@ -1893,6 +1802,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
throw t; throw t;
} }
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
if (ptrptr == null) if (ptrptr == null)
throw new RuntimeException(); throw new RuntimeException();
@ -1929,6 +1841,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
@Override @Override
public void registerGraph(long id, Pointer graph) { public void registerGraph(long id, Pointer graph) {
loop.registerGraph(null, id, graph); loop.registerGraph(null, id, graph);
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
} }
@Override @Override
@ -1952,7 +1867,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
val newMap = new LinkedHashMap<String, INDArray>(); val newMap = new LinkedHashMap<String, INDArray>();
OpaqueVariablesSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); OpaqueVariablesSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
OpStatus status = OpStatus.byNumber(loop.getVariablesSetStatus(result)); OpStatus status = OpStatus.byNumber(loop.getVariablesSetStatus(result));
@ -1996,6 +1914,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
@Override @Override
public void forgetGraph(long id) { public void forgetGraph(long id) {
loop.unregisterGraph(null, id); loop.unregisterGraph(null, id);
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
} }
/** /**
@ -2055,6 +1975,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
array.data().addressPointer(), (LongPointer) tadX.getFirst().addressPointer(), (LongPointer) tadX.getSecond().addressPointer(), null, null, null, array.data().addressPointer(), (LongPointer) tadX.getFirst().addressPointer(), (LongPointer) tadX.getSecond().addressPointer(), null, null, null,
updates.data().addressPointer(), (LongPointer) tadY.getFirst().addressPointer(), (LongPointer) tadY.getSecond().addressPointer(), null, null, null, updates.data().addressPointer(), (LongPointer) tadY.getFirst().addressPointer(), (LongPointer) tadY.getSecond().addressPointer(), null, null, null,
(IntPointer) indices.data().addressPointer(), null); (IntPointer) indices.data().addressPointer(), null);
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
} }
@Override @Override
@ -2078,6 +2001,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
val status = loop.execCustomOp2(null, op.opHash(), context.contextPointer()); val status = loop.execCustomOp2(null, op.opHash(), context.contextPointer());
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
if (status != 0) if (status != 0)
throw new RuntimeException("Op [" + op.opName() + "] execution failed"); throw new RuntimeException("Op [" + op.opName() + "] execution failed");
@ -2155,6 +2082,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
loop.inspectArray(null, array.data().addressPointer(), (LongPointer) array.shapeInfoDataBuffer().addressPointer(), null, null, debugInfo); loop.inspectArray(null, array.data().addressPointer(), (LongPointer) array.shapeInfoDataBuffer().addressPointer(), null, null, debugInfo);
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
return INDArrayStatistics.builder() return INDArrayStatistics.builder()
.minValue(debugInfo._minValue()) .minValue(debugInfo._minValue())
.maxValue(debugInfo._maxValue()) .maxValue(debugInfo._maxValue())
@ -2171,6 +2101,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
OpaqueConstantDataBuffer dbf = loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); OpaqueConstantDataBuffer dbf = loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
val result = new LongBuffer(loop.getConstantDataBufferPrimary(dbf), Shape.shapeInfoLength(shape.length)); val result = new LongBuffer(loop.getConstantDataBufferPrimary(dbf), Shape.shapeInfoLength(shape.length));
@ -2183,6 +2115,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
OpaqueTadPack pack = loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); OpaqueTadPack pack = loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
val tadShape = new LongBuffer(loop.getPrimaryShapeInfo(pack), loop.getShapeInfoLength(pack)); val tadShape = new LongBuffer(loop.getPrimaryShapeInfo(pack), loop.getShapeInfoLength(pack));
val tadOffsets = new LongBuffer(loop.getPrimaryOffsets(pack), loop.getNumberOfTads(pack)); val tadOffsets = new LongBuffer(loop.getPrimaryOffsets(pack), loop.getNumberOfTads(pack));
@ -2205,11 +2140,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
@Override @Override
public String runLightBenchmarkSuit(boolean printOut) { public String runLightBenchmarkSuit(boolean printOut) {
return loop.runLightBenchmarkSuit(printOut); val s = loop.runLightBenchmarkSuit(printOut);
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
return s;
} }
@Override @Override
public String runFullBenchmarkSuit(boolean printOut) { public String runFullBenchmarkSuit(boolean printOut) {
return loop.runFullBenchmarkSuit(printOut); val s = loop.runFullBenchmarkSuit(printOut);
if (loop.lastErrorCode() != 0)
throw new RuntimeException(loop.lastErrorMessage());
return s;
} }
} }

View File

@ -467,6 +467,60 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
// #endif //DEV_TESTS_TADPACK_H // #endif //DEV_TESTS_TADPACK_H
// Parsed from execution/ErrorReference.h
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
// #ifndef DEV_TESTS_ERRORREFERENCE_H
// #define DEV_TESTS_ERRORREFERENCE_H
// #include <string>
// #include <dll.h>
@Namespace("sd") @NoOffset public static class ErrorReference extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ErrorReference(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public ErrorReference(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public ErrorReference position(long position) {
return (ErrorReference)super.position(position);
}
public ErrorReference() { super((Pointer)null); allocate(); }
private native void allocate();
public native int errorCode();
public native @Cast("char*") String errorMessage();
public native void setErrorCode(int errorCode);
public native void setErrorMessage(@StdString BytePointer message);
public native void setErrorMessage(@StdString String message);
}
// #endif //DEV_TESTS_ERRORREFERENCE_H
// Parsed from Environment.h // Parsed from Environment.h
/******************************************************************************* /*******************************************************************************
@ -688,6 +742,18 @@ bool verbose = false;
// #include <graph/ResultWrapper.h> // #include <graph/ResultWrapper.h>
// #include <DebugInfo.h> // #include <DebugInfo.h>
/**
* This function returns last error code stored,
* @return non-zero if something bad happened
*/
public native int lastErrorCode();
/**
* This function returns last error message, if last error code > 0
* @return
*/
public native @Cast("char*") String lastErrorMessage();
/** /**
* *
* @param p * @param p
@ -1710,72 +1776,6 @@ public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraP
@Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets, @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets,
@Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ); @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ);
/**
* Append an input array
* to the end of a flat array
* in a particular order
* @param offset the offset of the array to start at
* @param order the order
* @param result the result array
* @param resultShapeInfo the shape info for te array
* @param input the input for the array
* @param inputShapeInfo the shape information for that array
*/
public native void flatten(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int offset,
char order,
Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo,
Pointer input, @Cast("Nd4jLong*") LongPointer inputShapeInfo,
Pointer dinput, @Cast("Nd4jLong*") LongPointer dinputShapeInfo);
public native void flatten(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int offset,
char order,
Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo,
Pointer input, @Cast("Nd4jLong*") LongBuffer inputShapeInfo,
Pointer dinput, @Cast("Nd4jLong*") LongBuffer dinputShapeInfo);
public native void flatten(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int offset,
char order,
Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo,
Pointer input, @Cast("Nd4jLong*") long[] inputShapeInfo,
Pointer dinput, @Cast("Nd4jLong*") long[] dinputShapeInfo);
public native void concat(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int dimension,
int numArrays,
@Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
@Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo,
@Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void concat(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int dimension,
int numArrays,
@Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
@Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo,
@Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void concat(
@Cast("Nd4jPointer*") PointerPointer extraPointers,
int dimension,
int numArrays,
@Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
@Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo,
Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo,
@Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void specialConcat( public native void specialConcat(
@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer*") PointerPointer extraPointers,
int dimension, int dimension,
@ -22877,6 +22877,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #include <dll.h> // #include <dll.h>
// #include <pointercast.h> // #include <pointercast.h>
// #include <execution/ErrorReference.h>
@Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer { @Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer {
static { Loader.load(); } static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
@ -22912,6 +22913,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native void setScalarBuffer(Pointer pointer); public native void setScalarBuffer(Pointer pointer);
public native void setAllocationBuffer(Pointer pointer); public native void setAllocationBuffer(Pointer pointer);
public native ErrorReference errorReference();
public native void triggerOwnership(@Cast("bool") boolean isOwner); public native void triggerOwnership(@Cast("bool") boolean isOwner);
public native int deviceId(); public native int deviceId();
@ -22961,6 +22964,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #include <vector> // #include <vector>
// #include <mutex> // #include <mutex>
// #include <execution/ContextBuffers.h> // #include <execution/ContextBuffers.h>
// #include <execution/ErrorReference.h>
@Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer { @Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer {
static { Loader.load(); } static { Loader.load(); }
@ -22985,9 +22989,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native int getDeviceID(); public native int getDeviceID();
public native void setDeviceID(int deviceID); public native void setDeviceID(int deviceID);
public native ErrorReference errorReference();
public static native @Cast("bool") boolean isInitialized(); public static native @Cast("bool") boolean isInitialized();
public static native void releaseBuffers(); public static native void releaseBuffers();
public static native LaunchContext defaultContext(); public static native LaunchContext defaultContext();

View File

@ -38,6 +38,7 @@ import java.util.Scanner;
"array/ConstantDataBuffer.h", "array/ConstantDataBuffer.h",
"array/ConstantDescriptor.h", "array/ConstantDescriptor.h",
"array/TadPack.h", "array/TadPack.h",
"execution/ErrorReference.h",
"Environment.h", "Environment.h",
"types/utf8string.h", "types/utf8string.h",
"NativeOps.h", "NativeOps.h",

View File

@ -5216,6 +5216,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
log.info("Array shapeInfo: {}", array.shapeInfoJava());
INDArray rev = Nd4j.reverse(array); INDArray rev = Nd4j.reverse(array);
assertEquals(exp, rev); assertEquals(exp, rev);
@ -5226,7 +5228,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0]; INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, array.ulike()))[0];
assertEquals(exp, rev); assertEquals(exp, rev);
} }
@ -5236,7 +5238,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}); INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0]; INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array,array.ulike()))[0];
assertEquals(exp, rev); assertEquals(exp, rev);
} }
@ -5335,11 +5337,103 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertNotNull(lsd); //Fails here on CUDA, OK on native/cpu assertNotNull(lsd); //Fails here on CUDA, OK on native/cpu
} }
@Test
public void testReverseSmall_1() {
val array = Nd4j.linspace(1, 10, 10, DataType.INT);
val exp = array.dup(array.ordering());
Transforms.reverse(array, false);
Transforms.reverse(array, false);
val jexp = exp.data().asInt();
val jarr = array.data().asInt();
assertArrayEquals(jexp, jarr);
assertEquals(exp, array);
}
@Test
public void testReverseSmall_2() {
val array = Nd4j.linspace(1, 10, 10, DataType.INT);
val exp = array.dup(array.ordering());
val reversed = Transforms.reverse(array, true);
val rereversed = Transforms.reverse(reversed, true);
val jexp = exp.data().asInt();
val jarr = rereversed.data().asInt();
assertArrayEquals(jexp, jarr);
assertEquals(exp, rereversed);
}
@Test
public void testReverseSmall_3() {
val array = Nd4j.linspace(1, 11, 11, DataType.INT);
val exp = array.dup(array.ordering());
Transforms.reverse(array, false);
log.info("Reversed shapeInfo: {}", array.shapeInfoJava());
log.info("Reversed: {}", array);
Transforms.reverse(array, false);
val jexp = exp.data().asInt();
val jarr = array.data().asInt();
assertArrayEquals(jexp, jarr);
assertEquals(exp, array);
}
@Test
public void testReverseSmall_4() {
val array = Nd4j.linspace(1, 11, 11, DataType.INT);
val exp = array.dup(array.ordering());
val reversed = Transforms.reverse(array, true);
log.info("Reversed: {}", reversed);
val rereversed = Transforms.reverse(reversed, true);
val jexp = exp.data().asInt();
val jarr = rereversed.data().asInt();
assertArrayEquals(jexp, jarr);
assertEquals(exp, rereversed);
}
@Test
public void testReverse_1() {
val array = Nd4j.linspace(1, 2017152, 2017152, DataType.INT);
val exp = array.dup(array.ordering());
Transforms.reverse(array, false);
Transforms.reverse(array, false);
val jexp = exp.data().asInt();
val jarr = array.data().asInt();
assertArrayEquals(jexp, jarr);
assertEquals(exp, array);
}
@Test
public void testReverse_2() {
val array = Nd4j.linspace(1, 2017152, 2017152, DataType.INT);
val exp = array.dup(array.ordering());
val reversed = Transforms.reverse(array, true);
val rereversed = Transforms.reverse(reversed, true);
val jexp = exp.data().asInt();
val jarr = rereversed.data().asInt();
assertArrayEquals(jexp, jarr);
assertEquals(exp, rereversed);
}
@Test @Test
public void testNativeSort3_1() { public void testNativeSort3_1() {
INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1); INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1);
INDArray exp = array.dup(); INDArray exp = array.dup();
Transforms.reverse(array, false); Transforms.reverse(array, false);
log.info("Reverse: {}", array);
long time1 = System.currentTimeMillis(); long time1 = System.currentTimeMillis();