[WIP] Error handling (#169)
* CUDA reverse rewrite + couple of tests
* don't throw exception on invalid pointer
* data types validation for fastpath exec mode + 2 tests
* ismax allowed dtypes tweak
* lastErrorCode + lastErrorMessage for native exceptions handling
* exportable ErrorReference
* check error codes in java
* consume lastErrorCode; fast_in dtype validation fix
* sg/cb allowed output type change; minor logging fix for data type validation

Signed-off-by: raver119 <raver119@gmail.com>
parent: bb5fc36e5e
commit: 25e5c23eae
@@ -79,6 +79,18 @@ bool verbose = false;

extern "C" {

/**
 * This function returns the last error code stored.
 * @return non-zero if something bad happened
 */
ND4J_EXPORT int lastErrorCode();

/**
 * This function returns the last error message, if the last error code > 0
 * @return
 */
ND4J_EXPORT const char* lastErrorMessage();

/**
 *
 * @param p
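The intended calling convention for these two exports: poll lastErrorCode() after every native call, and fetch lastErrorMessage() only when the code is non-zero. A minimal caller-side sketch, assuming nothing beyond the two declarations above (throwOnNativeError is a hypothetical helper, not part of the API):

    #include <stdexcept>
    #include <string>

    extern "C" int lastErrorCode();
    extern "C" const char* lastErrorMessage();

    // Hypothetical guard: call right after any native op invocation.
    static void throwOnNativeError() {
        if (lastErrorCode() != 0) {
            // fetching the message also consumes (resets) the stored error code
            throw std::runtime_error(std::string(lastErrorMessage()));
        }
    }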
@@ -557,38 +569,6 @@ ND4J_EXPORT void execScalarBoolTad(Nd4jPointer *extraPointers,
                                   Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                   Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

/**
 * Append an input array
 * to the end of a flat array
 * in a particular order
 * @param offset the offset of the array to start at
 * @param order the order
 * @param result the result array
 * @param resultShapeInfo the shape info for the array
 * @param input the input for the array
 * @param inputShapeInfo the shape information for that array
 */
ND4J_EXPORT void flatten(
        Nd4jPointer *extraPointers,
        int offset,
        char order,
        void *result, Nd4jLong *resultShapeInfo,
        void *dresult, Nd4jLong *dresultShapeInfo,
        void *input, Nd4jLong *inputShapeInfo,
        void *dinput, Nd4jLong *dinputShapeInfo);

ND4J_EXPORT void concat(
        Nd4jPointer *extraPointers,
        int dimension,
        int numArrays,
        Nd4jPointer *data, Nd4jPointer *inputShapeInfo,
        Nd4jPointer *ddata, Nd4jPointer *dinputShapeInfo,
        void *result, Nd4jLong *resultShapeInfo,
        void *dresult, Nd4jLong *dresultShapeInfo,
        Nd4jPointer *tadPointers, Nd4jPointer *offsetPointers);

ND4J_EXPORT void specialConcat (
        Nd4jPointer *extraPointers,
        int dimension,
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -23,6 +23,7 @@

#include <dll.h>
#include <pointercast.h>
#include <execution/ErrorReference.h>

namespace nd4j {
    class ND4J_EXPORT ContextBuffers {
@@ -32,6 +33,7 @@ namespace nd4j {
        void* _allocationPointer = nullptr;
        void* _execStream = nullptr;
        void* _specialStream = nullptr;
        sd::ErrorReference _errorReference;
        bool _allocated = false;
        bool _initialized = false;
@@ -60,6 +62,8 @@ namespace nd4j {
        void setScalarBuffer(void* pointer);
        void setAllocationBuffer(void* pointer);

        sd::ErrorReference* errorReference();

        void triggerOwnership(bool isOwner);

        int deviceId();
|
@ -15,32 +15,32 @@
|
|||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by Yurii Shyrma on 27.01.2018
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_PROVIDERRNG_H
|
||||
#define LIBND4J_PROVIDERRNG_H
|
||||
#ifndef DEV_TESTS_ERRORREFERENCE_H
|
||||
#define DEV_TESTS_ERRORREFERENCE_H
|
||||
|
||||
#include <helpers/helper_random.h>
|
||||
#include <mutex>
|
||||
|
||||
namespace nd4j {
|
||||
|
||||
class ProviderRNG {
|
||||
|
||||
protected:
|
||||
random::RandomBuffer* _rng;
|
||||
static std::mutex _mutex;
|
||||
ProviderRNG();
|
||||
#include <string>
|
||||
#include <dll.h>
|
||||
|
||||
namespace sd {
|
||||
class ND4J_EXPORT ErrorReference {
|
||||
private:
|
||||
int _errorCode = 0;
|
||||
std::string _errorMessage;
|
||||
public:
|
||||
ProviderRNG(const ProviderRNG&) = delete;
|
||||
void operator=(const ProviderRNG&) = delete;
|
||||
random::RandomBuffer* getRNG() const;
|
||||
static ProviderRNG& getInstance();
|
||||
};
|
||||
ErrorReference() = default;
|
||||
~ErrorReference() = default;
|
||||
|
||||
int errorCode();
|
||||
const char* errorMessage();
|
||||
|
||||
void setErrorCode(int errorCode);
|
||||
void setErrorMessage(std::string message);
|
||||
void setErrorMessage(const char* message);
|
||||
};
|
||||
}
|
||||
|
||||
#endif //LIBND4J_PROVIDERRNG_H
|
||||
|
||||
#endif //DEV_TESTS_ERRORREFERENCE_H
|
|
@@ -37,6 +37,7 @@
#include <vector>
#include <mutex>
#include <execution/ContextBuffers.h>
#include <execution/ErrorReference.h>
@@ -97,9 +98,12 @@ class ND4J_EXPORT LaunchContext {

    int getDeviceID() const {return _deviceID;}
    void setDeviceID(int deviceID) { _deviceID = deviceID; }
    sd::ErrorReference* errorReference();

    static bool isInitialized();
    static void releaseBuffers();

    static LaunchContext* defaultContext();
@@ -99,4 +99,8 @@ namespace nd4j {
    ContextBuffers& ContextBuffers::operator=(ContextBuffers&& other) {
        return *this;
    }

    sd::ErrorReference* ContextBuffers::errorReference() {
        return &_errorReference;
    }
}
@@ -23,7 +23,11 @@
#include <exceptions/cuda_exception.h>
#include <thread>

#ifdef IOS_BUILD
nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
#else
thread_local nd4j::ContextBuffers contextBuffers = nd4j::ContextBuffers();
#endif

namespace nd4j {
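Because contextBuffers is thread_local on non-iOS builds, every host thread owns a separate ErrorReference: an error recorded by one thread is invisible to (and cannot be consumed by) another. A small standalone C++ sketch of that isolation property, with ErrorRef standing in for sd::ErrorReference:

    #include <cstdio>
    #include <thread>

    struct ErrorRef { int code = 0; };      // stand-in for sd::ErrorReference
    thread_local ErrorRef perThreadError;   // mirrors the thread_local contextBuffers above

    int main() {
        std::thread worker([] {
            perThreadError.code = 42;       // error recorded on the worker thread
            std::printf("worker sees code %d\n", perThreadError.code);   // 42
        });
        worker.join();
        std::printf("main sees code %d\n", perThreadError.code);         // 0: untouched
    }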
@@ -65,4 +69,8 @@ namespace nd4j {
    void LaunchContext::releaseBuffers() {
        //
    }

    sd::ErrorReference* LaunchContext::errorReference() {
        return contextBuffers.errorReference();
    }
}
@@ -220,5 +220,9 @@ namespace nd4j {
    bool ContextBuffers::isInitialized() {
        return _initialized;
    }

    sd::ErrorReference* ContextBuffers::errorReference() {
        return &_errorReference;
    }
}
@@ -168,4 +168,8 @@ LaunchContext::LaunchContext() {
    bool LaunchContext::isInitialized() {
        return contextBuffers.isInitialized();
    }

    sd::ErrorReference* LaunchContext::errorReference() {
        return contextBuffers.errorReference();
    }
}
@@ -15,37 +15,32 @@
 ******************************************************************************/

 //
-// Created by Yurii Shyrma on 27.01.2018
+// @author raver119@gmail.com
 //

-#include <helpers/ProviderRNG.h>
-#include <NativeOps.h>
+#include <execution/ErrorReference.h>

-namespace nd4j {
-    ProviderRNG::ProviderRNG() {
-        Nd4jLong *buffer = new Nd4jLong[100000];
-        std::lock_guard<std::mutex> lock(_mutex);
-#ifndef __CUDABLAS__
-        // at this moment we don't have streams etc, so let's just skip this for now
-        _rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
-#endif
-        // if(_rng != nullptr)
-    }
-
-    ProviderRNG& ProviderRNG::getInstance() {
-        static ProviderRNG instance;
-        return instance;
-    }
-
-    random::RandomBuffer* ProviderRNG::getRNG() const {
-        return _rng;
-    }
-
-    std::mutex ProviderRNG::_mutex;
-}
+namespace sd {
+    int ErrorReference::errorCode() {
+        return _errorCode;
+    }
+
+    const char* ErrorReference::errorMessage() {
+        // since we're fetching error message - error code will be assumed consumed & nullified
+        _errorCode = 0;
+        return _errorMessage.c_str();
+    }
+
+    void ErrorReference::setErrorCode(int errorCode) {
+        _errorCode = errorCode;
+    }
+
+    void ErrorReference::setErrorMessage(std::string message) {
+        _errorMessage = message;
+    }
+
+    void ErrorReference::setErrorMessage(const char* message) {
+        _errorMessage = std::string(message);
+    }
+}
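Note the consume-on-read contract in errorMessage(): fetching the message zeroes the stored code, so a second poll reports no error. A short sketch of that behaviour, assuming only the class above (the message text is arbitrary):

    #include <cstdio>
    #include <execution/ErrorReference.h>

    int main() {
        sd::ErrorReference ref;
        ref.setErrorCode(1);
        ref.setErrorMessage("device allocation failed");   // hypothetical message

        std::printf("%d\n", ref.errorCode());     // 1
        std::printf("%s\n", ref.errorMessage());  // prints the message, resets the code
        std::printf("%d\n", ref.errorCode());     // 0: the error was consumed
    }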
@@ -45,7 +45,7 @@ DECLARE_SYN(IsMax, ismax);

DECLARE_TYPES(ismax) {
    getOpDescriptor()
            ->setAllowedInputTypes(0, DataType::ANY)
-           ->setAllowedOutputTypes(0, DataType::BOOL);
+           ->setAllowedOutputTypes(0, DataType::ANY);
}
@@ -84,7 +84,8 @@ namespace nd4j {
            ->setAllowedInputTypes(11, nd4j::DataType::INT64)
            ->setAllowedInputTypes(12, nd4j::DataType::INT32)
            ->setAllowedInputTypes(13, nd4j::DataType::INT32)
-           ->setAllowedInputTypes(14, {ALL_FLOATS});
+           ->setAllowedInputTypes(14, {ALL_FLOATS})
+           ->setAllowedOutputTypes(nd4j::DataType::ANY);
        }
    }
}
@@ -79,7 +79,7 @@ namespace nd4j {
            ->setAllowedInputTypes(9, {ALL_FLOATS})
            ->setAllowedInputTypes(10, nd4j::DataType::INT64)
            ->setAllowedInputTypes(11, {ALL_FLOATS})
-           ->setAllowedOutputTypes(nd4j::DataType::INT8);
+           ->setAllowedOutputTypes(nd4j::DataType::ANY);
        }

        /*
@@ -70,7 +70,7 @@ CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) {

DECLARE_TYPES(softmax_bp) {
    getOpDescriptor()
-           ->setAllowedInputTypes(DataType::ANY)
+           ->setAllowedInputTypes({ALL_FLOATS})
            ->setAllowedOutputTypes({ALL_FLOATS});
}
@@ -30,51 +30,9 @@ namespace nd4j {
namespace ops {
namespace helpers {

-    template <typename T>
-    inline void __device__ indexSwap(T* arr, Nd4jLong idx1, Nd4jLong idx2) {
-        T tmp = arr[idx1];
-        arr[idx1] = arr[idx2];
-        arr[idx2] = tmp;
-    }

//    template <typename T>
//    void reverseArray(nd4j::LaunchContext * context, void* inArr, Nd4jLong *inShapeBuffer, void *result, Nd4jLong *zShapeBuffer, int numOfElemsToReverse = 0);

    /////////////////////////////////////////////////////////////////////////////////////
-    template <typename T>
-    static __global__ void reverseArrayInplaceKernel(void *input, Nd4jLong *inputShape, Nd4jLong numOfElemsToReverse) {
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
-        const auto step = gridDim.x * blockDim.x;
-        __shared__ Nd4jLong length;
-        __shared__ int linearStatus;
-        __shared__ T* inputArr;
-        if (threadIdx.x == 0) {
-            length = shape::length(inputShape);
-            linearStatus = shape::elementWiseStride(inputShape);
-            inputArr = reinterpret_cast<T*>(input);
-        }
-        __syncthreads();
-
-        for (Nd4jLong e = tid; e < numOfElemsToReverse / 2; e += step) {
-            if (linearStatus == 1) {
-                auto idx = numOfElemsToReverse - e - 1;
-                indexSwap(inputArr, e, idx);
-            }
-            else if (linearStatus > 1) {
-                auto idx1 = (numOfElemsToReverse - e - 1) * linearStatus;
-                Nd4jLong idx2 = e * linearStatus;
-                indexSwap(inputArr, idx1, idx2);
-            }
-            else {
-                auto inOffset = shape::getIndexOffset(e, inputShape, length);
-                auto outOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape, length);
-                indexSwap(inputArr, inOffset, outOffset);
-            }
-        }
-    }

    template <typename T>
    static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) {
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+        const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
        const auto step = gridDim.x * blockDim.x;
        __shared__ Nd4jLong length;
        __shared__ int linearStatus;
@@ -93,51 +51,47 @@ namespace helpers {
        }
        __syncthreads();

-        for (Nd4jLong e = tid; e < length; e += step) {
-            if (e < numOfElemsToReverse) {
-                if (linearStatus == 1) {
-                    auto idx = numOfElemsToReverse - e - 1;
-                    outputArr[idx] = inputArr[e];
-                } else if (linearStatus > 1) {
-                    auto idx1 = (numOfElemsToReverse - e - 1) * linearStatus;
-                    Nd4jLong idx2 = e * linearStatus;
-                    outputArr[idx1] = inputArr[idx2];
-                } else {
-                    auto inOffset = shape::getIndexOffset(e, inputShape, length);
-                    auto outOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape, length);
-                    outputArr[outOffset] = inputArr[inOffset];
-                }
-            }
-            else {
-                if (linearStatus == 1) {
-                    outputArr[e] = inputArr[e];
-                } else if (linearStatus > 1) {
-                    auto idx1 = e * linearStatus;
-                    Nd4jLong idx2 = e * linearStatus;
-                    outputArr[idx1] = inputArr[idx2];
-                } else {
-                    auto inOffset = shape::getIndexOffset(e, inputShape, length);
-                    auto outOffset = shape::getIndexOffset(e, outputShape, length);
-                    outputArr[outOffset] = inputArr[inOffset];
-                }
-            }
-        }
+        auto odd = length % 2 != 0;
+        auto limit = length / 2;
+
+        for (Nd4jLong e = tid; e < limit; e += step) {
+            // we're calculating offsets within input array
+            auto fOffset = shape::getIndexOffset(e, inputShape, length);
+            auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape, length);
+
+            // now we're storing input values
+            auto v1 = inputArr[fOffset];
+            auto v2 = inputArr[lOffset];
+
+            // now we're calculating offsets within output array
+            auto zfOffset = shape::getIndexOffset(e, outputShape, length);
+            auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, outputShape, length);
+
+            // and saving values to output arrays
+            outputArr[zfOffset] = v2;
+            outputArr[zlOffset] = v1;
+
+            //printf("TID: %i; E: %lld; z[%lld], z[%lld] = x[%lld], x[%lld];\n", tid, e, zfOffset, zlOffset, lOffset, fOffset);
+        }
+
+        // in case of odd array we'll have to move middle value
+        if (odd && tid == 0) {
+            auto xOffset = shape::getIndexOffset(limit, inputShape, length);
+            auto zOffset = shape::getIndexOffset(limit, outputShape, length);
+
+            outputArr[zOffset] = inputArr[xOffset];
+            //printf("TID: %i; E: %lld; z[%lld] = x[%lld];\n", tid, limit, zOffset, xOffset);
+        }
    }

    template<typename T>
-    static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int numOfElemsToReverse) {
+    static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
        auto stream = context->getCudaStream();
        Nd4jLong numOfReverse = numOfElemsToReverse;
        if (numOfElemsToReverse == 0)
            numOfReverse = input->lengthOf();
-        if (input == output) {
-            reverseArrayInplaceKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), numOfReverse);
-        }
-        else {
-            reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
-        }
+
+        reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
    }
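The rewritten kernel swaps the pair (e, n - e - 1) for every e below n / 2 and lets a single thread copy the middle element when the length is odd; since both values of a pair are read before either is written, the same kernel now also covers the in-place (input == output) case, which is why the separate inplace kernel above was dropped. A host-side C++ sketch of the pairing scheme on a plain contiguous array:

    #include <cstdio>

    // Reverse the first n elements of in into out; safe when in == out,
    // because each pair (e, n-1-e) is read before either side is written.
    static void reversePairs(const float* in, float* out, long n) {
        const bool odd = n % 2 != 0;
        const long limit = n / 2;
        for (long e = 0; e < limit; e++) {
            float v1 = in[e];
            float v2 = in[n - e - 1];
            out[e] = v2;
            out[n - e - 1] = v1;
        }
        if (odd)                        // middle element maps onto itself
            out[limit] = in[limit];
    }

    int main() {
        float a[5] = {1, 2, 3, 4, 5};
        reversePairs(a, a, 5);          // in-place, like input == output above
        for (float v : a) std::printf("%g ", v);   // 5 4 3 2 1
        std::printf("\n");
    }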
@@ -221,7 +175,7 @@ namespace helpers {
        delete listIn;
    }

-    BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, int numOfElemsToReverse), LIBND4J_TYPES);
+    BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);

}
}
@@ -19,7 +19,6 @@
//

#include <ops/declarable/DeclarableOp.h>
-#include <helpers/ProviderRNG.h>
#include <Status.h>
#include <helpers/ShapeUtils.h>
#include <NDArrayFactory.h>
@@ -190,32 +189,6 @@ namespace nd4j {
        auto outSha = this->calculateOutputShape(&inSha, ctx);
        results = outSha->size();

-        // we must "validate" our output shapes
-        /*
-        for (int e = 0; e < results; e++) {
-            auto ptr = outSha->at(e);
-
-            // checking for the same pointer used twice
-            for (int i = 0; i < results; i++) {
-                if (i == e)
-                    continue;
-
-                auto com = outSha->at(i);
-
-                if (ptr == com)
-                    throw std::runtime_error("ShapeFunction returned same shape instance twice [" + *_descriptor->getOpName() + "]");
-            }
-
-            // checking for input pointer returned back
-            for (int i = 0; i < inSha.size(); i++) {
-                auto com = inSha.at(i);
-
-                if (ptr == com)
-                    throw std::runtime_error("ShapeFunction returned input shape instance as output [" + *_descriptor->getOpName() + "]");
-            }
-        }
-        */

        // optionally saving shapeTime
        if (Environment::getInstance()->isProfiling() && node != nullptr) {
            shapeEnd = std::chrono::system_clock::now();
@@ -355,75 +328,139 @@ namespace nd4j {
        // rolling over inputs first
        int cnt = 0, inT = 0;
        std::vector<nd4j::DataType> inputTypes(block.width());
        if (block.isFastPath()) {
            for (auto array: block.fastpath_in()) {
                inputTypes[inT++] = array->dataType();
                if (!_descriptor->checkInputMatch(cnt, array->dataType())) {
                    auto ctype = DataTypeUtils::asString(array->dataType());
                    nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n",
                                _descriptor->getOpName()->data(), cnt, ctype.c_str());
                    return ND4J_STATUS_BAD_ARGUMENTS;
                }
                cnt++;
            }
        } else {
            for (auto &p: *(block.inputs())) {
                auto var = block.variable(p);

                // we're not checking validity, if ANY types were explicitly allowed
                //if (block.dataType(cnt) == nd4j::DataType::ANY)
                //    continue;

                // only validating non-null variables
                if (var != nullptr && var->hasNDArray()) {
                    auto array = var->getNDArray();

                    inputTypes[inT++] = array->dataType();
                    if (!_descriptor->checkInputMatch(cnt, array->dataType())) {
                        auto ctype = DataTypeUtils::asString(array->dataType());
                        nd4j_printf("Op [%s] failed check for input [%i], DataType: [%s]\n",
                                    _descriptor->getOpName()->data(), cnt, ctype.c_str());
                        return ND4J_STATUS_BAD_ARGUMENTS;
                    }
                }

                cnt++;
            }
        }

        if (block.isFastPath()) {
            int index = 0;
            for (auto array: block.fastpath_out()) {
                auto cType = array->dataType();

                if (_descriptor->isSameMode()) {
                    if (index >= block.width()) {
                        auto ia = block.fastpath_in()[0];

                        if (ia->dataType() != cType) {
                            auto t = DataTypeUtils::asString(cType);
                            nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
                                        _descriptor->getOpName()->data(), index, t.c_str());
                            return ND4J_STATUS_BAD_ARGUMENTS;
                        }
                    } else {
                        // for same mode, output type must be the same as input type
                        auto ia = block.fastpath_in()[index];

                        if (ia->dataType() != cType) {
                            auto t = DataTypeUtils::asString(cType);
                            nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
                                        _descriptor->getOpName()->data(), index, t.c_str());
                            return ND4J_STATUS_BAD_ARGUMENTS;
                        }
                    }
                } else if (_descriptor->isInherit(index)) {
                    // in inherit mode, output type must be the same as one of input types
                    if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) {
                        auto t = DataTypeUtils::asString(cType);
                        nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n",
                                    _descriptor->getOpName()->data(), index, t.c_str());
                        return ND4J_STATUS_BAD_ARGUMENTS;
                    }
                } else if (!_descriptor->checkOutputMatch(index, cType)) {
                    auto t = DataTypeUtils::asString(cType);
                    nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n",
                                _descriptor->getOpName()->data(), index, t.c_str());
                    return ND4J_STATUS_BAD_ARGUMENTS;
                }
                index++;
            }
        } else {
            // checking optionally available outputs
            auto varSpace = block.getVariableSpace();
            for (int index = 0; index < DataTypeUtils::max<int>(); index++) {
                if (varSpace != nullptr && varSpace->hasVariable(block.nodeId(), index)) {
                    auto var = block.variable(block.nodeId(), index);

                    // only validating non-null variables
                    if (var != nullptr && var->hasNDArray()) {
                        auto array = var->getNDArray();
                        auto cType = array->dataType();

                        if (_descriptor->isSameMode()) {
                            if (index >= block.width()) {
                                auto iv = block.variable(0);

                                if (iv->getNDArray()->dataType() != cType) {
                                    auto t = DataTypeUtils::asString(cType);
                                    nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
                                                _descriptor->getOpName()->data(), index, t.c_str());
                                    return ND4J_STATUS_BAD_ARGUMENTS;
                                }
                            } else {
                                // for same mode, output type must be the same as input type
                                auto iv = block.variable(index);

                                if (iv->getNDArray()->dataType() != cType) {
                                    auto t = DataTypeUtils::asString(cType);
                                    nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s]\n",
                                                _descriptor->getOpName()->data(), index, t.c_str());
                                    return ND4J_STATUS_BAD_ARGUMENTS;
                                }
                            }
                        } else if (_descriptor->isInherit(index)) {
                            // in inherit mode, output type must be the same as one of input types
                            if (std::find(inputTypes.begin(), inputTypes.end(), cType) == inputTypes.end()) {
                                auto t = DataTypeUtils::asString(cType);
                                nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s].\n",
                                            _descriptor->getOpName()->data(), index, t.c_str());
                                return ND4J_STATUS_BAD_ARGUMENTS;
                            }
                        } else if (!_descriptor->checkOutputMatch(index, cType)) {
                            auto t = DataTypeUtils::asString(cType);
                            nd4j_printf("Op [%s] failed check for output [%i], DataType: [%s];\n",
                                        _descriptor->getOpName()->data(), index, t.c_str());
                            return ND4J_STATUS_BAD_ARGUMENTS;
                        }
                    }
                } else
                    break;
            }
        }
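Distilled, the output check applies one of three policies per output slot: SAME (output must match the corresponding input, falling back to input 0 when there are more outputs than inputs), INHERIT (output must match at least one input type), or the descriptor's explicit allow-list. A compact sketch of that decision order, using hypothetical stand-in types rather than the real OpDescriptor API:

    #include <algorithm>
    #include <vector>

    enum class Policy { SAME, INHERIT, EXPLICIT };   // stand-in for descriptor flags

    // Returns true when the output type at `index` is acceptable; mirrors the
    // decision order above, with `allowed` standing in for checkOutputMatch().
    static bool checkOutput(Policy p, int index, int outType,
                            const std::vector<int>& inputTypes,
                            const std::vector<int>& allowed) {
        if (p == Policy::SAME) {
            int ref = index < (int) inputTypes.size() ? inputTypes[index] : inputTypes[0];
            return outType == ref;
        }
        if (p == Policy::INHERIT)
            return std::find(inputTypes.begin(), inputTypes.end(), outType) != inputTypes.end();
        return std::find(allowed.begin(), allowed.end(), outType) != allowed.end();
    }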
@@ -400,6 +400,32 @@ TEST_F(JavaInteropTests, Test_Synonyms_3) {
    ASSERT_EQ(nameRef, name);
}

TEST_F(JavaInteropTests, Test_FastPath_Validation_1) {
    auto x = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
    auto z = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});

    Context ctx(1);
    ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());

    nd4j::ops::softmax op;
    auto status = op.execute(&ctx);
    ASSERT_NE(Status::OK(), status);
}

TEST_F(JavaInteropTests, Test_FastPath_Validation_2) {
    auto x = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
    auto z = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});

    Context ctx(1);
    ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());

    nd4j::ops::softmax op;
    auto status = op.execute(&ctx);
    ASSERT_NE(Status::OK(), status);
}

/*
TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
    int inOutH = 35;
@@ -992,81 +992,6 @@ TEST_F(NativeOpsTests, ScalarTadTest_2) {
    ASSERT_TRUE(exp.e<bool>(5) == z.e<bool>(5) && exp.e<bool>(15));
}

-TEST_F(NativeOpsTests, FlattenTest_1) {
-    auto x = NDArrayFactory::create<float>('c', {5, 5});
-    auto y = NDArrayFactory::create<float>('c', {5, 5});
-    auto exp = NDArrayFactory::create<float>('c', {2, 5, 5});
-    auto z = NDArrayFactory::create<float>('c', {2, 5, 5});
-
-    Nd4jPointer extra[6];
-#ifdef __CUDABLAS__
-    extra[1] = x.getContext()->getCudaStream();
-    extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr;
-    x.syncToHost();
-    y.syncToHost();
-    printf("Unsupported for CUDA platform yet.\n");
-    return;
-#endif
-    x.linspace(1.0, 2);
-    y.linspace(2, 2);
-
-    //y.assign(2.);
-    x.syncToDevice();
-    z.syncToDevice();
-    auto dimension = NDArrayFactory::create<int>({0, 1});
-    auto dimensions = reinterpret_cast<int*>(dimension.buffer());
-    auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
-    auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
-    exp(1, {0}).linspace(1, 2);
-    ::flatten(extra,
-              25, 'c', z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
-              x.buffer(), x.shapeInfo(),
-              x.specialBuffer(), x.specialShapeInfo());
-
-    //  exp.printIndexedBuffer("Exp");
-    //  z.printIndexedBuffer("Flatten");
-    ASSERT_TRUE(exp.equalsTo(z));
-}
-
-TEST_F(NativeOpsTests, ConcatTest_1) {
-    auto x = NDArrayFactory::create<float>('c', {5, 5});
-    auto y = NDArrayFactory::create<float>('c', {5, 5});
-    auto exp = NDArrayFactory::create<float>('c', {10, 5});
-    auto z = NDArrayFactory::create<float>('c', {10, 5});
-
-    Nd4jPointer extra[6];
-#ifdef __CUDABLAS__
-    extra[1] = x.getContext()->getCudaStream();
-    extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr;
-    x.syncToHost();
-    y.syncToHost();
-    printf("Unsupported for CUDA platform yet.\n");
-    return;
-#endif
-    x.linspace(1.0);
-    y.linspace(26);
-
-    //y.assign(2.);
-    x.syncToDevice();
-    z.syncToDevice();
-    int d = 0;
-    auto dimension = NDArrayFactory::create<int>('c', {1}, {d});
-    auto dimensions = reinterpret_cast<int*>(dimension.buffer());
-    //auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf());
-    auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf());
-    exp.linspace(1);
-    Nd4jPointer datas[] = {x.buffer(), y.buffer()};
-    Nd4jPointer shapes[] = {x.shapeInfo(), y.shapeInfo()};
-
-    ::concat(extra,
-             0, 2, datas, shapes, nullptr, nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
-             nullptr, nullptr);
-
-    //  exp.printIndexedBuffer("Exp");
-    //  z.printIndexedBuffer("Concat");
-    ASSERT_TRUE(exp.equalsTo(z));
-}

TEST_F(NativeOpsTests, ConcatTest_2) {
    auto x = NDArrayFactory::create<float>('c', {5, 5});
    auto y = NDArrayFactory::create<float>('c', {5, 5});
@@ -557,43 +557,6 @@ public interface NativeOps {
            @Cast("Nd4jLong *") LongPointer tadShapeInfo, @Cast("Nd4jLong *") LongPointer tadOffsets,
            @Cast("Nd4jLong *") LongPointer tadShapeInfoZ, @Cast("Nd4jLong *") LongPointer tadOffsetsZ);

-    /**
-     * @param extraPointers
-     * @param offset
-     * @param order
-     * @param results
-     * @param resultShapeInfo
-     * @param input
-     * @param inputShapeInfo
-     */
-    void flatten(PointerPointer extraPointers,
-                 int offset,
-                 char order,
-                 Pointer results, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
-                 Pointer dresults, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
-                 Pointer input, @Cast("Nd4jLong *") LongPointer inputShapeInfo,
-                 Pointer dinput, @Cast("Nd4jLong *") LongPointer dinputShapeInfo);
-
-    /**
-     * @param extraPointers
-     * @param dimension
-     * @param numArrays
-     * @param data
-     * @param inputShapeInfo
-     * @param results
-     * @param resultShapeInfo
-     * @param tadPointers
-     * @param tadOffsets
-     */
-    void concat(PointerPointer extraPointers,
-                int dimension,
-                int numArrays,
-                PointerPointer data, PointerPointer inputShapeInfo,
-                PointerPointer ddata, PointerPointer dinputShapeInfo,
-                Pointer results, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
-                Pointer dresults, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
-                PointerPointer tadPointers,
-                PointerPointer tadOffsets);

    void specialConcat(PointerPointer extraPointers,
                       int dimension,
@@ -1185,4 +1148,7 @@ public interface NativeOps {
    Pointer lcCopyStream(OpaqueLaunchContext lc);
    Pointer lcBlasHandle(OpaqueLaunchContext lc);
    Pointer lcSolverHandle(OpaqueLaunchContext lc);

    int lastErrorCode();
    String lastErrorMessage();
}
@@ -22,6 +22,7 @@ import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.linalg.exception.ND4JException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;

/**
@@ -67,14 +68,18 @@ public class cudaEvent_t extends CudaPointer {
            int res = NativeOpsHolder.getInstance().getDeviceNativeOps().eventSynchronize(this);
            if (res == 0)
                throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() + "]");

            if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
                throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
        }
    }

    public void register(cudaStream_t stream) {
        if (!isDestroyed()) {
            int res = NativeOpsHolder.getInstance().getDeviceNativeOps().registerEvent(this, stream);
            if (res == 0)
                throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() + "]");

            if (NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode() != 0)
                throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage());
        }
    }
}
@@ -36,8 +36,9 @@ public class cudaStream_t extends CudaPointer {
    public int synchronize() {
        NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        int res = nativeOps.streamSynchronize(this);
        if (res == 0)
            throw new ND4JException("CUDA exception happened. Terminating. Last op: [" + Nd4j.getExecutioner().getLastOp() + "]");

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        return res;
    }
@@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.ops.custom.Flatten;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
@@ -104,6 +105,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
        functions.put(11, Loader.addressof("cusolverDnSgesvd"));
        functions.put(12, Loader.addressof("cusolverDnDgesvd"));
        nativeOps.initializeFunctions(functions);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    @Override
@@ -335,75 +339,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
        if (Nd4j.getExecutioner() instanceof GridExecutioner)
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

-        int length = 0;
-        DataType t = null;
-        for (INDArray m : matrices) {
-            length += m.length();
-            if (t == null)
-                t = m.dataType();
-
-            Preconditions.checkArgument(t == m.dataType(), "Arrays must have same data type");
-        }
-
-        INDArray ret = Nd4j.create(t, new long[] {length}, order);
-        int linearIndex = 0;
-
-        AtomicAllocator allocator = AtomicAllocator.getInstance();
-
-        for (INDArray m : matrices) {
-            if (m.isEmpty())
-                continue;
-
-            CudaContext context = allocator.getFlowController().prepareAction(ret, m);
-
-            if (m.ordering() == order && ret.elementWiseStride() == m.elementWiseStride()
-                    && ret.elementWiseStride() == 1) {
-                // do memcpy in proper direction and forget about that
-                // FIXME: get rid of this
-                ((BaseCudaDataBuffer) m.data()).lazyAllocateHostPointer();
-                allocator.memcpyAsync(ret.data(), new CudaPointer(allocator.getHostPointer(m).address()),
-                        AllocationUtils.getRequiredMemory(AllocationUtils.buildAllocationShape(m)),
-                        linearIndex * (m.data().dataType() == DataType.DOUBLE ? 8
-                                : m.data().dataType() == DataType.FLOAT ? 4 : 2));
-                linearIndex += m.length();
-            } else {
-                Pointer hostYShapeInfo = AddressRetriever.retrieveHostPointer(m.shapeInfoDataBuffer());
-
-                PointerPointer extras = new PointerPointer(
-                        AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()), context.getOldStream(),
-                        allocator.getDeviceIdPointer(), null,
-                        context.getBufferReduction(), context.getBufferScalar(), null,
-                        hostYShapeInfo, AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()));
-
-                nativeOps.flatten(extras, linearIndex, order,
-                        null,
-                        (LongPointer) allocator.getHostPointer(ret.shapeInfoDataBuffer()),
-                        allocator.getPointer(ret, context),
-                        (LongPointer) allocator.getPointer(ret.shapeInfoDataBuffer(), context),
-                        null,
-                        (LongPointer) allocator.getHostPointer(m.shapeInfoDataBuffer()),
-                        allocator.getPointer(m, context),
-                        (LongPointer) allocator.getPointer(m.shapeInfoDataBuffer(), context));
-
-                //Works for all cases...
-
-                /* NdIndexIterator iter = new NdIndexIterator(order, m.shape());
-                while (iter.hasNext()) {
-                    ret.putScalar(linearIndex++, m.getDouble(iter.next()));
-                }*/
-
-                linearIndex += m.length();
-            }
-
-            if (ret != null)
-                allocator.registerAction(context, ret, m);
-        }
-        return ret;
+        return Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[0])))[0];
    }

    @Override
@@ -412,131 +348,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();

        return Nd4j.exec(new Concat(dimension, toConcat))[0];

-        // legacy implementation
-        /*
-        boolean allScalars = true;
-
-        var outputShape = ArrayUtil.copy(toConcat[0].shape());
-
-        if (toConcat.length == 1)
-            return toConcat[0];
-
-        int sumAlongDim = 0;
-        for (int i = 0; i < toConcat.length; i++) {
-            if (toConcat[i].isCompressed())
-                Nd4j.getCompressor().decompressi(toConcat[i]);
-
-            allScalars &= toConcat[i].rank() == 0;
-
-            sumAlongDim += toConcat[i].size(dimension);
-        }
-
-        if (allScalars) {
-            outputShape = new long[]{sumAlongDim};
-        } else {
-            outputShape[dimension] = sumAlongDim;
-        }
-
-        INDArray ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order());
-
-        AtomicAllocator allocator = AtomicAllocator.getInstance();
-
-        CudaContext context = allocator.getFlowController().prepareAction(ret, toConcat);
-
-        val shapeInfoPointers = new long[toConcat.length];
-        val dataPointers = new long[toConcat.length];
-        val tadPointers = new long[toConcat.length];
-        val offsetsPointers = new long[toConcat.length];
-        val hostShapeInfoPointers = new long[toConcat.length];
-
-        TADManager tadManager = Nd4j.getExecutioner().getTADManager();
-        for (int i = 0; i < toConcat.length; i++) {
-            shapeInfoPointers[i] = AddressRetriever.retrieveDeviceAddress(toConcat[i].shapeInfoDataBuffer(), context);
-            dataPointers[i] = AtomicAllocator.getInstance().getPointer(toConcat[i], context).address();
-            hostShapeInfoPointers[i] = AtomicAllocator.getInstance().getHostPointer(toConcat[i].shapeInfoDataBuffer()).address();
-
-            sumAlongDim += toConcat[i].size(dimension);
-            for (int j = 0; j < toConcat[i].rank(); j++)
-                if (j != dimension && toConcat[i].size(j) != outputShape[j]) {
-                    throw new IllegalArgumentException(
-                            "Illegal concatenation at array " + i + " and shape element " + j);
-                }
-
-            if (!allScalars) {
-                val tadBuffers = tadManager.getTADOnlyShapeInfo(toConcat[i], new int[]{dimension});
-
-                long devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context).address();
-
-                val offsets = tadBuffers.getSecond();
-                long devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context).address();
-
-                tadPointers[i] = devTadShapeInfo;
-                offsetsPointers[i] = devTadOffsets;
-            }
-        }
-
-        // getting tadOnlyShape for result
-        val zBuffers = tadManager.getTADOnlyShapeInfo(ret, new int[] {dimension});
-        val hostPointers = new LongPointer(hostShapeInfoPointers);
-        val hosthost = new PointerPointerWrapper(hostPointers);
-
-        //System.out.println("shapePointers: " + Arrays.toString(shapeInfoPointers));
-
-        val dZ = AtomicAllocator.getInstance().getPointer(ret, context);
-        val dZShapeInfo = AddressRetriever.retrieveDevicePointer(ret.shapeInfoDataBuffer(), context);
-
-        //val tempData = new CudaDoubleDataBuffer(toConcat.length);
-        //val tempShapes = new CudaDoubleDataBuffer(toConcat.length);
-        //val tempTAD = new CudaDoubleDataBuffer(toConcat.length);
-        //val tempOffsets = new CudaDoubleDataBuffer(toConcat.length);
-
-        //AtomicAllocator.getInstance().memcpyBlocking(tempData, new LongPointer(dataPointers), dataPointers.length * 8, 0);
-        //AtomicAllocator.getInstance().memcpyBlocking(tempShapes, new LongPointer(shapeInfoPointers), shapeInfoPointers.length * 8, 0);
-        //AtomicAllocator.getInstance().memcpyBlocking(tempTAD, new LongPointer(tadPointers), tadPointers.length * 8, 0);
-        //AtomicAllocator.getInstance().memcpyBlocking(tempOffsets, new LongPointer(offsetsPointers), offsetsPointers.length * 8, 0);
-
-        val dataPointer = new PointerPointerWrapper(new LongPointer(dataPointers)); //AtomicAllocator.getInstance().getPointer(tempData, context);
-        val shapesPointer = new PointerPointerWrapper(new LongPointer(shapeInfoPointers)); //AtomicAllocator.getInstance().getPointer(tempShapes, context);
-        //val tadPointer = AtomicAllocator.getInstance().getPointer(tempTAD, context);
-        //val offsetPointer = AtomicAllocator.getInstance().getPointer(tempOffsets, context);
-
-        // System.out.println("ShapesPointer after conversion: " + shapesPointer);
-
-        val extras = new PointerPointer(AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()),
-                context.getOldStream(), allocator.getDeviceIdPointer(), null,
-                context.getBufferReduction(), context.getBufferScalar(), null,
-                AddressRetriever.retrieveHostPointer(toConcat[0].shapeInfoDataBuffer()),
-                AddressRetriever.retrieveHostPointer(ret.shapeInfoDataBuffer()),
-                new LongPointer(hostShapeInfoPointers),
-                AtomicAllocator.getInstance().getPointer(zBuffers.getFirst(), context), // getting zTADShape
-                AtomicAllocator.getInstance().getPointer(zBuffers.getSecond(), context) // getting zOffset
-        );
-
-        nativeOps.concat(extras,
-                dimension,
-                toConcat.length,
-                null,
-                hosthost,
-                dataPointer,
-                shapesPointer,
-                null,
-                (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
-                dZ,
-                (LongPointer) dZShapeInfo,
-                null,
-                null);
-
-        allocator.registerAction(context, ret, toConcat);
-
-        return ret;
-        //return super.concat(dimension, toConcat);
-        */
    }
@@ -590,6 +401,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
                (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
                null, null);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        AllocationPoint point = allocator.getAllocationPoint(ret);
@@ -598,6 +411,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
        nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), ret.lengthLong() * Nd4j.sizeOfDataType(ret.data().dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
        context.getSpecialStream().synchronize();

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        PerformanceTracker.getInstance().helperRegisterTransaction(point.getDeviceId(), perfD, point.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);

        point.tickHostRead();
@@ -729,6 +545,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
                (LongPointer) zTadShapeInfo,
                new LongPointerWrapper(zTadOffsets));

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        allocator.registerAction(context, ret, source);
@@ -743,7 +561,6 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
            return target.assign(arrays[0]);

        // we do averaging on GPU only if ALL devices have p2p links
        //if (CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed() && nativeOps.isP2PAvailable()) {
        if (true) {
            Nd4j.getExecutioner().push();
@@ -781,6 +598,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {

            nativeOps.accumulate(extras, null, (LongPointer) arrays[0].shapeInfoDataBuffer().addressPointer(), x, null, null, (LongPointer) allocator.getHostPointer(target.shapeInfoDataBuffer()), z, (LongPointer) allocator.getPointer(target.shapeInfoDataBuffer()), arrays.length, len);

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());

            allocator.getFlowController().registerAction(context, target, arrays);

            return target;
@@ -824,6 +644,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
                    arrays.length,
                    len);

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());

            AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite();
@@ -895,6 +717,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
                    arrays.length,
                    len, true);

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());

            allocator.getFlowController().registerAction(context, target, arrays);

            return target;
@@ -940,6 +765,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
                    arrays.length,
                    len, true);

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());

            if (target != null)
                AtomicAllocator.getInstance().getAllocationPoint(target).tickHostWrite();
@@ -1115,6 +943,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
                (IntPointer) shuffleMap, new PointerPointer(allocator.getPointer(tempTAD, context)),
                new PointerPointer(allocator.getPointer(tempOffsets, context)));

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        for (int f = 0; f < arrays.size(); f++) {
            allocator.getFlowController().registerAction(context, arrays.get(f));
        }
@@ -1260,6 +1091,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
        val p = new PointerPointer<>(new Pointer[]{null, stream});

        nativeOps.convertTypes(p, typeSrc.ordinal(), source, length, typeDst.ordinal(), target);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    @Override
@@ -1277,7 +1111,13 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
            srcPtr = nativeOps.mallocDevice(ssize, 0, 0);
            dstPtr = nativeOps.mallocDevice(size, 0, 0);

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());

            nativeOps.memcpyAsync(srcPtr, source, ssize, CudaConstants.cudaMemcpyHostToDevice, stream);

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());
        } else {
            // decompressing
            throw new UnsupportedOperationException();
@@ -1288,9 +1128,15 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {

        stream.synchronize();

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        if (buffer instanceof CompressedDataBuffer) {
            nativeOps.freeDevice(srcPtr, 0);
            nativeOps.freeDevice(dstPtr, 0);

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());
        }
    }
@@ -1309,13 +1155,15 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
            val size = ((CompressedDataBuffer) source).getCompressionDescriptor().getCompressedLength();
            srcPtr = ws.alloc(size, MemoryKind.DEVICE, DataType.HALF, false);
            nativeOps.memcpyAsync(srcPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream);

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());
        }

        // if true - we're compressing into host memory
        if (target instanceof CompressedDataBuffer) {
            val size = ((CompressedDataBuffer) target).getCompressionDescriptor().getCompressedLength();
            dstPtr = ws.alloc(size, MemoryKind.DEVICE, DataType.HALF, false);
            //nativeOps.memcpyAsync(dstPtr, target.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream);
        }
    } else {
        // if true - we're decompressing from host memory
@@ -1325,6 +1173,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
            srcPtr = nativeOps.mallocDevice(size, 0, 0);
            nativeOps.memcpyAsync(srcPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream);
            stream.synchronize();

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());
        } else
            srcPtr = AtomicAllocator.getInstance().getPointer(source);
@@ -1333,8 +1184,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
            log.info("Replacing target ptr");
            val size = ((CompressedDataBuffer) target).getCompressionDescriptor().getCompressedLength();
            dstPtr = nativeOps.mallocDevice(size, 0, 0);
            //nativeOps.memcpyAsync(dstPtr, source.addressPointer(), size, CudaConstants.cudaMemcpyHostToHost, stream);
            //stream.synchronize();

            if (nativeOps.lastErrorCode() != 0)
                throw new RuntimeException(nativeOps.lastErrorMessage());
        } else
            dstPtr = AtomicAllocator.getInstance().getPointer(target);
    }
@@ -1342,6 +1194,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {

        convertDataEx(typeSrc, srcPtr, typeDst, dstPtr, target.length());

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        Nd4j.getExecutioner().commit();
@@ -1364,6 +1219,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        Nd4j.getExecutioner().commit();
    }
@@ -1462,6 +1320,9 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
                new LongPointerWrapper(AtomicAllocator.getInstance().getPointer(tadBuffers.getSecond(), context))
        );

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        AtomicAllocator.getInstance().getFlowController().registerActionAllWrite(context, result);
        AtomicAllocator.getInstance().getFlowController().registerAction(context, null, result);
@@ -1517,6 +1378,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
                descending
        );

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        AtomicAllocator.getInstance().getFlowController().registerAction(context, x);
@@ -1565,6 +1428,8 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
                descending
        );

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        AtomicAllocator.getInstance().getFlowController().registerAction(context, x);
@@ -207,6 +207,10 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());

        profilingConfigurableHookOut(op, st);
@@ -461,6 +465,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            }
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, st);

        return op.z();
@@ -619,7 +626,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
                AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
                null);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
@@ -777,6 +785,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            throw new UnsupportedOperationException("Unknown opType: " + op.getOpType());
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());

        profilingConfigurableHookOut(op, st);
@@ -868,6 +879,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            AtomicAllocator.getInstance().registerAction(context, null, op.x(), op.y());
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, st);

        return null;
@@ -1105,6 +1119,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, st);
@@ -1194,6 +1210,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            throw new UnsupportedOperationException();
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y());

        profilingConfigurableHookOut(op, st);
@@ -1268,6 +1287,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.scalar());

        profilingConfigurableHookOut(op, st);
@ -1423,6 +1445,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
}
|
||||
}
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
|
||||
|
||||
|
@ -1582,6 +1606,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(),
|
||||
AtomicAllocator.getInstance().getPointer(surfaceBuffer, context), FlatBuffersMapper.getDataTypeAsByte(dataType));
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
surfacePoint.tickHostWrite();
|
||||
}
|
||||
|
||||
|
@ -1676,6 +1703,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
numIndexArguments, iPtr, numIntArrays,
|
||||
AtomicAllocator.getInstance().getPointer(realsBuffer.data(), context),
|
||||
numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType));
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1739,6 +1769,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context));
|
||||
}
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
AtomicAllocator.getInstance().getFlowController().registerAction(context, op.z(), op.x(), op.y());
|
||||
|
||||
profilingConfigurableHookOut(op, st);
|
||||
|
@ -1969,6 +2002,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
nativeOps.decodeThreshold(extras, AtomicAllocator.getInstance().getPointer(buffer), compressedLength, AtomicAllocator.getInstance().getPointer(result), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer()));
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
AtomicAllocator.getInstance().getAllocationPoint(result).tickDeviceWrite();
|
||||
|
||||
return target;
|
||||
|
@ -2013,7 +2049,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
(IntPointer) AtomicAllocator.getInstance().getPointer(buffer, context),
|
||||
(float) threshold);
|
||||
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
AtomicAllocator.getInstance().getFlowController().registerAction(context, indArray);
|
||||
|
||||
|
@ -2039,6 +2076,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
nativeOps.decodeBitmap(extras, AtomicAllocator.getInstance().getPointer(encoded.data(), context), target.lengthLong(), AtomicAllocator.getInstance().getPointer(target, context), (LongPointer) AtomicAllocator.getInstance().getHostPointer(target.shapeInfoDataBuffer()));
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
AtomicAllocator.getInstance().getFlowController().registerAction(context, target);
|
||||
|
||||
|
@ -2151,6 +2190,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
if (ptrptr == null)
|
||||
throw new RuntimeException();
|
||||
|
||||
|
@ -2221,109 +2263,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
} catch (Exception e) {
|
||||
throw new RuntimeException("Op [" + name + "] execution failed", e);
|
||||
}
|
||||
|
||||
/*
|
||||
long st = profilingConfigurableHookIn(op);
|
||||
|
||||
CudaContext context =(CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
||||
//AtomicAllocator.getInstance().getFlowController().prepareActionAllWrite(op.outputArguments());
|
||||
|
||||
if (extraz.get() == null)
|
||||
extraz.set(new PointerPointer(32));
|
||||
|
||||
|
||||
PointerPointer extras = extraz.get().put(
|
||||
new CudaPointer(1),
|
||||
context.getOldStream(),
|
||||
context.getBufferScalar(),
|
||||
context.getBufferReduction());
|
||||
|
||||
val outputArgs = op.outputArguments();
|
||||
val inputArgs = op.inputArguments();
|
||||
|
||||
if (outputArgs.length == 0 && !op.isInplaceCall())
|
||||
throw new ND4JIllegalStateException("You can't execute non-inplace CustomOp without outputs being specified");
|
||||
|
||||
val lc = op.opName().toLowerCase();
|
||||
val hash = op.opHash();
|
||||
|
||||
|
||||
val inputShapes = new PointerPointer<>(inputArgs.length * 2);
|
||||
val inputBuffers = new PointerPointer<>(inputArgs.length * 2);
|
||||
|
||||
int cnt= 0;
|
||||
for (val in: inputArgs) {
|
||||
val hp = AtomicAllocator.getInstance().getHostPointer(in.shapeInfoDataBuffer());
|
||||
inputBuffers.put(cnt, AtomicAllocator.getInstance().getHostPointer(in));
|
||||
inputShapes.put(cnt, hp);
|
||||
|
||||
|
||||
val dp = AtomicAllocator.getInstance().getPointer(in.shapeInfoDataBuffer(), context);
|
||||
|
||||
inputBuffers.put(cnt + inputArgs.length, AtomicAllocator.getInstance().getPointer(in, context));
|
||||
inputShapes.put(cnt+ inputArgs.length, dp);
|
||||
|
||||
if (op.isInplaceCall()) {
|
||||
val ap = AtomicAllocator.getInstance().getAllocationPoint(in);
|
||||
if (ap != null)
|
||||
ap.tickHostWrite();
|
||||
}
|
||||
|
||||
cnt++;
|
||||
}
|
||||
|
||||
|
||||
val outputShapes = new PointerPointer<>(outputArgs.length * 2);
|
||||
val outputBuffers = new PointerPointer<>(outputArgs.length * 2);
|
||||
|
||||
cnt= 0;
|
||||
for (val out: outputArgs) {
|
||||
outputBuffers.put(cnt, AtomicAllocator.getInstance().getHostPointer(out));
|
||||
outputShapes.put(cnt, AtomicAllocator.getInstance().getHostPointer(out.shapeInfoDataBuffer()));
|
||||
|
||||
outputBuffers.put(cnt + outputArgs.length, AtomicAllocator.getInstance().getPointer(out, context));
|
||||
outputShapes.put(cnt + outputArgs.length, AtomicAllocator.getInstance().getPointer(out.shapeInfoDataBuffer(), context));
|
||||
|
||||
val ap = AtomicAllocator.getInstance().getAllocationPoint(out);
|
||||
|
||||
if (ap != null)
|
||||
ap.tickHostWrite();
|
||||
|
||||
cnt++;
|
||||
}
|
||||
|
||||
val iArgs = op.iArgs().length > 0 ? new LongPointer(op.iArgs().length) : null;
|
||||
|
||||
cnt = 0;
|
||||
for (val i: op.iArgs())
|
||||
iArgs.put(cnt++, i);
|
||||
|
||||
|
||||
val tArgs = op.tArgs().length > 0 ? new DoublePointer(op.tArgs().length) : null;
|
||||
|
||||
val bArgs = op.bArgs().length > 0 ? new BooleanPointer(op.numBArguments()) : null;
|
||||
|
||||
cnt = 0;
|
||||
for (val t: op.tArgs())
|
||||
tArgs.put(cnt++, t);
|
||||
|
||||
cnt = 0;
|
||||
for (val b: op.bArgs())
|
||||
bArgs.put(cnt++, b);
|
||||
|
||||
try {
|
||||
val status = OpStatus.byNumber(nativeOps.execCustomOp(extras, hash, inputBuffers, inputShapes, inputArgs.length, outputBuffers, outputShapes, outputArgs.length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments(), op.isInplaceCall()));
|
||||
if (status != OpStatus.ND4J_STATUS_OK)
|
||||
throw new ND4JIllegalStateException("Op execution failed: " + status);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Op [" + op.opName() + "] execution failed");
|
||||
}
|
||||
|
||||
//AtomicAllocator.getInstance().getFlowController().prepareActionAllWrite(op.outputArguments());
|
||||
|
||||
profilingConfigurableHookOut(op, st);
|
||||
return op.outputArguments();
|
||||
*/
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -2341,6 +2280,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
@Override
|
||||
public void registerGraph(long id, Pointer graph) {
|
||||
nativeOps.registerGraph(null, id, graph);
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -2368,6 +2310,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(result));
|
||||
|
||||
if (status != OpStatus.ND4J_STATUS_OK)
|
||||
|
@ -2398,6 +2343,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
newMap.put(nodeName, array);
|
||||
}
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
nativeOps.deleteVariablesSet(result);
|
||||
|
||||
return newMap;
|
||||
|
@ -2406,6 +2354,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
@Override
|
||||
public void forgetGraph(long id) {
|
||||
nativeOps.unregisterGraph(null, id);
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -2474,6 +2425,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(tadY.getFirst()), null, AtomicAllocator.getInstance().getPointer(updates, context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getFirst()), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getSecond()),
|
||||
null, (IntPointer) AtomicAllocator.getInstance().getPointer(indices, context));
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
AtomicAllocator.getInstance().getFlowController().registerAction(context, array, indices, updates);
|
||||
}
|
||||
|
||||
|
@ -2490,9 +2444,14 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation());
|
||||
|
||||
val status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
|
||||
if (status != 0)
|
||||
throw new RuntimeException("Op [" + op.opName() + "] execution failed");
|
||||
|
||||
|
||||
|
||||
for (val arr:op.outputArguments())
|
||||
AtomicAllocator.getInstance().registerAction(ctx, arr);
|
||||
|
||||
|
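In the custom-op path above, the call is guarded twice: the native error record is consulted first, then the plain status code returned by execCustomOp2 itself. A compact sketch of that ordering, with runCustomOp as a hypothetical wrapper name that is not part of this changeset:

// Hypothetical wrapper; mirrors the double guard around execCustomOp2 above.
static void runCustomOp(org.nd4j.nativeblas.NativeOps nativeOps, long opHash,
                        org.bytedeco.javacpp.Pointer contextPointer, String opName) {
    int status = nativeOps.execCustomOp2(null, opHash, contextPointer);

    // A native-side exception recorded via the error reference wins first...
    if (nativeOps.lastErrorCode() != 0)
        throw new RuntimeException(nativeOps.lastErrorMessage());

    // ...then the ordinary status code from the call itself is checked.
    if (status != 0)
        throw new RuntimeException("Op [" + opName + "] execution failed");
}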
@@ -2527,6 +2486,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {

        nativeOps.inspectArray(extras, AtomicAllocator.getInstance().getHostPointer(array), (LongPointer) AtomicAllocator.getInstance().getHostPointer(array.shapeInfoDataBuffer()), AtomicAllocator.getInstance().getPointer(array, ctx), (LongPointer) AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()), debugInfo);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        return INDArrayStatistics.builder()
                .minValue(debugInfo._minValue())
                .maxValue(debugInfo._maxValue())

@@ -2545,6 +2507,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
    public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
        OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val result = new CudaLongDataBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), Shape.shapeInfoLength(shape.length));

        nativeOps.deleteShapeBuffer(dbf);

@@ -2556,6 +2521,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
    public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
        OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val tadShape = new CudaLongDataBuffer(nativeOps.getPrimaryShapeInfo(pack), nativeOps.getSpecialShapeInfo(pack), nativeOps.getShapeInfoLength(pack));
        val tadOffsets = new CudaLongDataBuffer(nativeOps.getPrimaryOffsets(pack), nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack));

@@ -2568,6 +2536,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
    public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
        OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
        buffer.setConstant(true);

@@ -2578,6 +2549,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
    public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
        OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
        buffer.setConstant(true);
@@ -449,6 +449,60 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
// #endif //DEV_TESTS_TADPACK_H


// Parsed from execution/ErrorReference.h

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

// #ifndef DEV_TESTS_ERRORREFERENCE_H
// #define DEV_TESTS_ERRORREFERENCE_H

// #include <string>
// #include <dll.h>
@Namespace("sd") @NoOffset public static class ErrorReference extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ErrorReference(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public ErrorReference(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    @Override public ErrorReference position(long position) {
        return (ErrorReference)super.position(position);
    }

    public ErrorReference() { super((Pointer)null); allocate(); }
    private native void allocate();

    public native int errorCode();
    public native @Cast("char*") String errorMessage();

    public native void setErrorCode(int errorCode);
    public native void setErrorMessage(@StdString BytePointer message);
    public native void setErrorMessage(@StdString String message);
}


// #endif //DEV_TESTS_ERRORREFERENCE_H
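The generated binding above maps the native sd::ErrorReference class one-to-one: an integer code plus a message string, with setters used by the native side and getters consumed from Java. A purely illustrative exercise of the mapped surface follows; real callers go through lastErrorCode()/lastErrorMessage(), which read the same per-context state:

// Illustrative only: exercising the mapped error slot directly.
ErrorReference ref = new ErrorReference();
ref.setErrorCode(1);
ref.setErrorMessage("something bad happened");

if (ref.errorCode() != 0)
    System.out.println("native error: " + ref.errorMessage());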
// Parsed from memory/MemoryType.h

//

@@ -688,6 +742,18 @@ bool verbose = false;
// #include <graph/ResultWrapper.h>
// #include <DebugInfo.h>

/**
 * This function returns last error code stored,
 * @return non-zero if something bad happened
 */
public native int lastErrorCode();

/**
 * This function returns last error message, if last error code > 0
 * @return
 */
public native @Cast("char*") String lastErrorMessage();

/**
 *
 * @param p

@@ -1710,72 +1776,6 @@ public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraP
                @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets,
                @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ);


/**
 * Append an input array
 * to the end of a flat array
 * in a particular order
 * @param offset the offset of the array to start at
 * @param order the order
 * @param result the result array
 * @param resultShapeInfo the shape info for te array
 * @param input the input for the array
 * @param inputShapeInfo the shape information for that array
 */
public native void flatten(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int offset,
        char order,
        Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo,
        Pointer input, @Cast("Nd4jLong*") LongPointer inputShapeInfo,
        Pointer dinput, @Cast("Nd4jLong*") LongPointer dinputShapeInfo);
public native void flatten(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int offset,
        char order,
        Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo,
        Pointer input, @Cast("Nd4jLong*") LongBuffer inputShapeInfo,
        Pointer dinput, @Cast("Nd4jLong*") LongBuffer dinputShapeInfo);
public native void flatten(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int offset,
        char order,
        Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo,
        Pointer input, @Cast("Nd4jLong*") long[] inputShapeInfo,
        Pointer dinput, @Cast("Nd4jLong*") long[] dinputShapeInfo);

public native void concat(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int dimension,
        int numArrays,
        @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
        Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void concat(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int dimension,
        int numArrays,
        @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
        Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void concat(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int dimension,
        int numArrays,
        @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
        Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);


public native void specialConcat(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int dimension,

@@ -9950,6 +9950,7 @@ public static final int PREALLOC_SIZE = 33554432;

// #include <dll.h>
// #include <pointercast.h>
// #include <execution/ErrorReference.h>
@Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */

@@ -9985,6 +9986,8 @@ public static final int PREALLOC_SIZE = 33554432;
    public native void setScalarBuffer(Pointer pointer);
    public native void setAllocationBuffer(Pointer pointer);

    public native ErrorReference errorReference();

    public native void triggerOwnership(@Cast("bool") boolean isOwner);

    public native int deviceId();

@@ -10038,6 +10041,7 @@ public static final int PREALLOC_SIZE = 33554432;
// #include <vector>
// #include <mutex>
// #include <execution/ContextBuffers.h>
// #include <execution/ErrorReference.h>

@Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer {
    static { Loader.load(); }

@@ -10067,9 +10071,12 @@ public static final int PREALLOC_SIZE = 33554432;

    public native int getDeviceID();
    public native void setDeviceID(int deviceID);
    public native ErrorReference errorReference();

    public static native @Cast("bool") boolean isInitialized();
    public static native void releaseBuffers();


    public static native LaunchContext defaultContext();
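ContextBuffers and LaunchContext now both hand out the per-context ErrorReference, which is the storage that the exported lastErrorCode()/lastErrorMessage() pair reads for the calling thread. A rough sketch of reaching the same slot through the generated mapping, assuming the default context has been initialized:

// Rough sketch, assuming LaunchContext.defaultContext() is initialized;
// this mirrors what lastErrorCode()/lastErrorMessage() consult internally.
Nd4jCuda.LaunchContext ctx = Nd4jCuda.LaunchContext.defaultContext();
Nd4jCuda.ErrorReference err = ctx.errorReference();

if (err.errorCode() != 0)
    throw new RuntimeException(err.errorMessage());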
@@ -32,6 +32,7 @@ import org.bytedeco.javacpp.tools.InfoMapper;
                "array/ConstantDescriptor.h",
                "array/ConstantDataBuffer.h",
                "array/TadPack.h",
                "execution/ErrorReference.h",
                "memory/MemoryType.h",
                "Environment.h",
                "types/utf8string.h",
@@ -106,6 +106,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
        functions.put(8, Loader.addressof("LAPACKE_sgesdd"));
        functions.put(9, Loader.addressof("LAPACKE_dgesdd"));
        nativeOps.initializeFunctions(functions);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    @Override

@@ -489,32 +492,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
    @Override
    public INDArray toFlattened(char order, Collection<INDArray> matrices) {
        Preconditions.checkArgument(matrices.size() > 0, "toFlattened expects > 0 operands");
        /*
        int length = 0;
        val list = new ArrayList<INDArray>(matrices);
        val t = list.get(0).dataType();
        for (INDArray m : matrices) {
            length += m.length();
            Preconditions.checkArgument(m.dataType() == t, "All operands must have same data type");
        }

        INDArray ret = Nd4j.create(t, new long[] {length}, order);
        int linearIndex = 0;
        PointerPointer dummy = new PointerPointer(new Pointer[] {null});
        for (INDArray m : matrices) {
            Nd4j.getCompressor().autoDecompress(m);

            nativeOps.flatten(dummy, linearIndex, order,
                    ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
                    null, null,
                    m.data().addressPointer(),
                    (LongPointer) m.shapeInfoDataBuffer().addressPointer(),
                    null, null);

            linearIndex += m.length();
        }
        return ret;
        */
        return Nd4j.exec(new Flatten(order, matrices.toArray(new INDArray[matrices.size()])))[0];
    }

@@ -555,6 +533,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
                new LongPointerWrapper(tadBuffers.getSecond().pointer())
        );

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        return result;
    }

@@ -574,65 +555,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
            return toConcat[0];

        return Nd4j.exec(new Concat(dimension, toConcat))[0];

        // legacy implementation
        /*
        // if reusable var wasn't created for this thread, or is smaller then needed - set it to new value
        if (extrazA.get() == null || extrazB.get() == null || extrazSize.get() == null || extrazSize.get() < toConcat.length) {
            extrazA.set(new PointerPointer(toConcat.length));
            extrazB.set(new PointerPointer(toConcat.length));
            extrazSize.set(toConcat.length);
        }

        PointerPointer shapeInfoPointers = extrazA.get();
        PointerPointer dataPointers = extrazB.get();
        int sumAlongDim = 0;

        long[] outputShape = ArrayUtil.copy(toConcat[0].shape());

        boolean allScalars = true;

        for (int i = 0; i < toConcat.length; i++) {
            Preconditions.checkState(toConcat[i].rank() == outputShape.length, "Encountered different array ranks for concat: input[0].shape()=%ndShape, input[%s].shape()=%ndShape",
                    toConcat[0], i, toConcat[i]);

            if (toConcat[i].isCompressed())
                Nd4j.getCompressor().decompressi(toConcat[i]);

            Preconditions.checkArgument(toConcat[i].dataType() == toConcat[0].dataType(), "All operands must have same data type: input 0 has type %s, input %s has type %s",
                    toConcat[0].dataType(), i, toConcat[i].dataType());

            allScalars &= toConcat[i].rank() == 0;

            shapeInfoPointers.put(i, toConcat[i].shapeInfoDataBuffer().addressPointer());
            dataPointers.put(i, toConcat[i].data().addressPointer());
            sumAlongDim += toConcat[i].size(dimension);
            for (int j = 0; j < toConcat[i].rank(); j++) {

                if (j != dimension && toConcat[i].size(j) != outputShape[j]) {
                    throw new IllegalArgumentException(
                            "Illegal concatenation at array " + i + " and shape element " + j);
                }
            }
        }

        if (allScalars) {
            outputShape = new long[]{sumAlongDim};
        } else {
            outputShape[dimension] = sumAlongDim;
        }

        INDArray ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order());

        nativeOps.concat(null, dimension, toConcat.length,
                dataPointers, shapeInfoPointers,
                null, null,
                ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
                null, null,
                null, null);

        return ret;
        */
    }
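With the legacy flatten and concat native entry points removed, both operations now route through the generic custom-op executor, so the new lastErrorCode() guard there covers them as well. A minimal usage sketch follows; the input shapes are illustrative, and the op classes are the Flatten and Concat types referenced in the factory code above:

// Minimal sketch of the replacement call path.
INDArray a = Nd4j.rand(2, 3);
INDArray b = Nd4j.rand(2, 3);

// Flatten several arrays into one rank-1 result, in 'c' order:
INDArray flat = Nd4j.exec(new Flatten('c', a, b))[0];      // shape [12]

// Concatenate along dimension 0:
INDArray joined = Nd4j.exec(new Concat(0, a, b))[0];       // shape [4, 3]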
@@ -757,6 +679,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
                (LongPointer) zTadShapeInfo,
                new LongPointerWrapper(zTadOffsets));

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        return ret;
    }

@@ -794,6 +718,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
                arrays.length,
                len);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        return target;
    }

@@ -846,6 +773,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
                len,
                true);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        return target;
    }

@@ -983,6 +913,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
                arrays.size(),
                ptrMap, tadPointers, offsetPointers);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        dataPointers.address();
        shapePointers.address();

@@ -990,84 +922,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
        offsetPointers.address();
    }

    /**
     * This method converts Half-precision databuffer to current dType buffer.
     *
     * @param buffer
     * @return
     */
    /*
    @Override
    public DataBuffer restoreFromHalfs(DataBuffer buffer) {
        if (buffer.dataType() != DataType.COMPRESSED)
            throw new IllegalStateException("DataBuffer contains wrong data: " + buffer.dataType());

        CompressedDataBuffer comp = (CompressedDataBuffer) buffer;
        CompressionDescriptor descriptor = comp.getCompressionDescriptor();

        DataBuffer targetBuffer = Nd4j.createBuffer(descriptor.getCompressedLength() / 2);

        if (Nd4j.dataType() == DataType.DOUBLE) {
            nativeOps.convertHalfsToDoubles(
                    null,
                    comp.addressPointer(),
                    (int) descriptor.getCompressedLength() / 2,
                    targetBuffer.addressPointer()
            );
        } else if (Nd4j.dataType() == DataType.FLOAT) {
            nativeOps.convertHalfsToFloats(
                    null,
                    comp.addressPointer(),
                    (int) descriptor.getCompressedLength() / 2,
                    targetBuffer.addressPointer()
            );
        } else {
            throw new UnsupportedOperationException("Target dtype isn't supported: " + Nd4j.dataType());
        }

        return targetBuffer;
    }
    */

    /**
     * This method converts Single/Double precision databuffer to Half-precision databuffer
     *
     * @param buffer
     * @return
     */
    /*@Override
    public DataBuffer convertToHalfs(DataBuffer buffer) {
        // we allocate pointer
        ShortPointer pointer = new ShortPointer(buffer.length());

        if (buffer.dataType() == DataType.DOUBLE) {
            nativeOps.convertDoublesToHalfs(
                    null,
                    buffer.addressPointer(),
                    (int) buffer.length(),
                    pointer
            );
        } else if (buffer.dataType() == DataType.FLOAT) {
            nativeOps.convertFloatsToHalfs(
                    null,
                    buffer.addressPointer(),
                    (int) buffer.length(),
                    pointer
            );
        } else {
            throw new UnsupportedOperationException("Source dtype isn't supported: " + buffer.dataType());
        }

        CompressionDescriptor descriptor = new CompressionDescriptor(buffer, new Float16());
        descriptor.setCompressedLength(buffer.length() * 2);


        CompressedDataBuffer result = new CompressedDataBuffer(pointer, descriptor);
        return result;
    }
    */

    /**
     * This method converts Single/Double precision databuffer to Half-precision databuffer
     *

@@ -1081,6 +935,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
            throw new UnsupportedOperationException("Impossible to compress View. Consider using dup() before. ");

        DataBuffer buffer = convertDataEx(typeSrc, source.data(), typeDst);
        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        source.setData(buffer);

        if (buffer instanceof CompressedDataBuffer)

@@ -1125,6 +982,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {

        convertDataEx(typeSrc, source, typeDst, buffer);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        return buffer;
    }

@@ -1132,6 +992,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
    public void convertDataEx(DataTypeEx typeSrc, Pointer source, DataTypeEx typeDst, Pointer target,
                    long length) {
        nativeOps.convertTypes(null, typeSrc.ordinal(), source, length, typeDst.ordinal(), target);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    @Override
@@ -234,6 +234,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                    null);
        }

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        profilingConfigurableHookOut(op, st);
        return op.z();
    }

@@ -563,6 +566,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
            }
        }

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        return ret;
    }

@@ -644,6 +650,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                throw new UnsupportedOperationException();
        }

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());
    }

    public INDArray exec(ScalarOp op) {

@@ -690,6 +698,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]");
        }

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        profilingConfigurableHookOut(op, st);

        return op.z();

@@ -886,6 +897,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {

        }

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        profilingConfigurableHookOut(op, st);
    }

@@ -962,6 +976,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                throw new UnsupportedOperationException("Unknown operation type: [" + op.getOpType() + "]");
        }

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        return op.z();
    }

@@ -1091,6 +1107,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                    batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(),
                    batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), pointer, FlatBuffersMapper.getDataTypeAsByte(dataType));

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());
    }

    /**

@@ -1197,6 +1216,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                    numIndexArguments, intArrays, numIntArrays, block.getRealArgumentsPointer(),
                    numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType));

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());
    }

    /**

@@ -1284,6 +1305,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                    op.extraArgsDataBuff(op.z().dataType()).addressPointer());
        }

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        profilingConfigurableHookOut(op, st);

        return op.z();

@@ -1370,6 +1394,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                (float) threshold);
        //long t2 = System.currentTimeMillis();

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        if (cntAbs < 2)
            return null;

@@ -1429,6 +1456,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {

        loop.convertTypes(null, DataTypeEx.THRESHOLD.ordinal(), buffer.addressPointer(), target.length(), typeDst.ordinal(), target.data().addressPointer());

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        return target;
    }

@@ -1460,6 +1490,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                (IntPointer) buffer.addressPointer(),
                (float) threshold);

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        return affected;
    }

@@ -1473,6 +1506,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                (LongPointer) target.shapeInfoDataBuffer().addressPointer()
        );

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        return target;
    }
@@ -1673,136 +1709,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
        } catch (Exception e) {
            throw new RuntimeException("Op [" + name + "] execution failed", e);
        }
        /*
        val name = op.opName().toLowerCase();
        val hash = op.opHash();

        if (name.equals("noop")) {
            return op.outputArguments();
        }

        val inputShapes = getInputShapes(op.numInputArguments());
        val inputBuffers = getInputBuffers(op.numInputArguments());

        int cnt= 0;
        val inputArgs = op.inputArguments();
        for (val in: inputArgs) {
            if(in == null)
                throw new NullPointerException("Input argument is null for op " + op.getClass().getName());

            if (!in.isEmpty())
                inputBuffers.put(cnt, in.data().addressPointer());

            inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer());
        }

        val outputArgs = op.outputArguments();
        for(int i = 0; i < outputArgs.length; i++) {
            if(outputArgs[i] == null)
                throw new ND4JIllegalStateException("Op output arguments must not be null! Op " + op.getClass().getName());
        }


        val outputShapes = getOutputShapes(op.numOutputArguments());
        val outputBuffers = getOutputBuffers(op.numOutputArguments());

        cnt= 0;
        for (val out: outputArgs) {
            if(out.isEmpty()){
                outputBuffers.put(cnt, null);
            } else {
                outputBuffers.put(cnt, out.data().addressPointer());
            }
            outputShapes.put(cnt++, out.shapeInfoDataBuffer().addressPointer());
        }

        val iArgs = op.numIArguments() > 0 ? getLongPointerFrom(iArgsPointer,op.numIArguments()) : null;
        val tArgs = op.numTArguments() > 0 ? getDoublePointerFrom(tArgsPointer,op.numTArguments()) : null;
        val bArgs = op.numBArguments() > 0 ? getBooleanPointerFrom(bArgsPointer,op.numBArguments()) : null;

        cnt = 0;
        val iArgs1 = op.iArgs();
        for (val i: iArgs1)
            iArgs.put(cnt++, i);

        cnt = 0;
        val bArgs1 = op.bArgs();
        for (val b: bArgs1)
            bArgs.put(cnt++, b);

        cnt = 0;
        val tArgs1 = op.tArgs();
        for (val t: tArgs1)
            tArgs.put(cnt++, t);

        val t = op.numInputArguments();

        OpStatus status = OpStatus.ND4J_STATUS_OK;
        try {
            val code = loop.execCustomOp(
                    null,
                    hash,
                    inputBuffers,
                    inputShapes,
                    op.numInputArguments(),
                    outputBuffers,
                    outputShapes,
                    op.numOutputArguments(),
                    tArgs, op.numTArguments(),
                    iArgs, op.numIArguments(),
                    bArgs, op.numBArguments(),
                    op.isInplaceCall());

            status = OpStatus.byNumber(code);

            if (status != OpStatus.ND4J_STATUS_OK)
                throw new ND4JIllegalStateException("Failed to execute op [" + name + "] with error code [" + status +"]");
        }catch(Exception e) {
            val sb = new StringBuilder();
            sb.append("Inputs: [(");
            for( int i=0; i<inputArgs.length; i++ ){
                if(i > 0)
                    sb.append("), (");
                sb.append(Shape.shapeToStringShort(inputArgs[i]));
            }
            sb.append(")]. Outputs: [(");
            for( int i=0; i<outputArgs.length; i++){
                if(i > 0)
                    sb.append("), (");
                sb.append(Shape.shapeToStringShort(outputArgs[i]));
            }
            sb.append(")]. tArgs: ");
            if(op.numTArguments() > 0){
                sb.append(Arrays.toString(op.tArgs()));
            } else {
                sb.append("-");
            }
            sb.append(". iArgs: ");
            if(op.numIArguments() > 0){
                sb.append(Arrays.toString(op.iArgs()));
            } else {
                sb.append("-");
            }
            if(op instanceof DifferentialFunction){
                String n = ((DifferentialFunction) op).getOwnName();
                if(n != null && !n.equals(op.opName())){
                    sb.append(". Op own name: \"").append(n).append("\"");
                }
            }
            log.error("Failed to execute op " + op.opName() + ". Attempted to execute with " +
                    String.valueOf(op.numInputArguments()) + " inputs, " +
                    String.valueOf(op.numOutputArguments()) + " outputs, "+
                    String.valueOf(op.numTArguments()) + " targs and " +
                    String.valueOf(op.numIArguments()) + " iargs. " +
                    sb.toString() +
                    " - Please see above message (printed out from c++) for a possible cause of error.");
            throw e;
        }

        profilingConfigurableHookOut(op, st);

        return op.outputArguments();
        */
    }

    protected LongShapeDescriptor getShapeFromPointer(LongPointer ptr) {
@@ -1870,6 +1776,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
            ptrptr = loop.calculateOutputShapes2(null,
                    hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
                    op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments());

            if (loop.lastErrorCode() != 0)
                throw new RuntimeException(loop.lastErrorMessage());
        } catch (Throwable t){
            StringBuilder sb = new StringBuilder();
            sb.append("Inputs: [(");

@@ -1893,6 +1802,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
            throw t;
        }

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        if (ptrptr == null)
            throw new RuntimeException();

@@ -1929,6 +1841,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
    @Override
    public void registerGraph(long id, Pointer graph) {
        loop.registerGraph(null, id, graph);

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());
    }

    @Override

@@ -1952,7 +1867,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {

        val newMap = new LinkedHashMap<String, INDArray>();

        OpaqueVariablesSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        OpStatus status = OpStatus.byNumber(loop.getVariablesSetStatus(result));

@@ -1996,6 +1914,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
    @Override
    public void forgetGraph(long id) {
        loop.unregisterGraph(null, id);
        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());
    }

    /**

@@ -2055,6 +1975,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
                array.data().addressPointer(), (LongPointer) tadX.getFirst().addressPointer(), (LongPointer) tadX.getSecond().addressPointer(), null, null, null,
                updates.data().addressPointer(), (LongPointer) tadY.getFirst().addressPointer(), (LongPointer) tadY.getSecond().addressPointer(), null, null, null,
                (IntPointer) indices.data().addressPointer(), null);

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());
    }

    @Override

@@ -2078,6 +2001,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {


        val status = loop.execCustomOp2(null, op.opHash(), context.contextPointer());

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        if (status != 0)
            throw new RuntimeException("Op [" + op.opName() + "] execution failed");

@@ -2155,6 +2082,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {

        loop.inspectArray(null, array.data().addressPointer(), (LongPointer) array.shapeInfoDataBuffer().addressPointer(), null, null, debugInfo);

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        return INDArrayStatistics.builder()
                .minValue(debugInfo._minValue())
                .maxValue(debugInfo._maxValue())

@@ -2171,6 +2101,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
    @Override
    public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
        OpaqueConstantDataBuffer dbf = loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        val result = new LongBuffer(loop.getConstantDataBufferPrimary(dbf), Shape.shapeInfoLength(shape.length));

@@ -2183,6 +2115,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
    public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
        OpaqueTadPack pack = loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);

        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        val tadShape = new LongBuffer(loop.getPrimaryShapeInfo(pack), loop.getShapeInfoLength(pack));
        val tadOffsets = new LongBuffer(loop.getPrimaryOffsets(pack), loop.getNumberOfTads(pack));

@@ -2205,11 +2140,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {

    @Override
    public String runLightBenchmarkSuit(boolean printOut) {
        return loop.runLightBenchmarkSuit(printOut);
        val s = loop.runLightBenchmarkSuit(printOut);
        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        return s;
    }

    @Override
    public String runFullBenchmarkSuit(boolean printOut) {
        return loop.runFullBenchmarkSuit(printOut);
        val s = loop.runFullBenchmarkSuit(printOut);
        if (loop.lastErrorCode() != 0)
            throw new RuntimeException(loop.lastErrorMessage());

        return s;
    }
}
@@ -467,6 +467,60 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
// #endif //DEV_TESTS_TADPACK_H


// Parsed from execution/ErrorReference.h

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

// #ifndef DEV_TESTS_ERRORREFERENCE_H
// #define DEV_TESTS_ERRORREFERENCE_H

// #include <string>
// #include <dll.h>
@Namespace("sd") @NoOffset public static class ErrorReference extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ErrorReference(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public ErrorReference(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    @Override public ErrorReference position(long position) {
        return (ErrorReference)super.position(position);
    }

    public ErrorReference() { super((Pointer)null); allocate(); }
    private native void allocate();

    public native int errorCode();
    public native @Cast("char*") String errorMessage();

    public native void setErrorCode(int errorCode);
    public native void setErrorMessage(@StdString BytePointer message);
    public native void setErrorMessage(@StdString String message);
}


// #endif //DEV_TESTS_ERRORREFERENCE_H


// Parsed from Environment.h

/*******************************************************************************

@@ -688,6 +742,18 @@ bool verbose = false;
// #include <graph/ResultWrapper.h>
// #include <DebugInfo.h>

/**
 * This function returns last error code stored,
 * @return non-zero if something bad happened
 */
public native int lastErrorCode();

/**
 * This function returns last error message, if last error code > 0
 * @return
 */
public native @Cast("char*") String lastErrorMessage();

/**
 *
 * @param p

@@ -1710,72 +1776,6 @@ public native void execScalarBoolTad(@Cast("Nd4jPointer*") PointerPointer extraP
                @Cast("Nd4jLong*") long[] tadShapeInfo, @Cast("Nd4jLong*") long[] tadOffsets,
                @Cast("Nd4jLong*") long[] tadShapeInfoZ, @Cast("Nd4jLong*") long[] tadOffsetsZ);


/**
 * Append an input array
 * to the end of a flat array
 * in a particular order
 * @param offset the offset of the array to start at
 * @param order the order
 * @param result the result array
 * @param resultShapeInfo the shape info for te array
 * @param input the input for the array
 * @param inputShapeInfo the shape information for that array
 */
public native void flatten(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int offset,
        char order,
        Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo,
        Pointer input, @Cast("Nd4jLong*") LongPointer inputShapeInfo,
        Pointer dinput, @Cast("Nd4jLong*") LongPointer dinputShapeInfo);
public native void flatten(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int offset,
        char order,
        Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo,
        Pointer input, @Cast("Nd4jLong*") LongBuffer inputShapeInfo,
        Pointer dinput, @Cast("Nd4jLong*") LongBuffer dinputShapeInfo);
public native void flatten(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int offset,
        char order,
        Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo,
        Pointer input, @Cast("Nd4jLong*") long[] inputShapeInfo,
        Pointer dinput, @Cast("Nd4jLong*") long[] dinputShapeInfo);

public native void concat(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int dimension,
        int numArrays,
        @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
        Pointer result, @Cast("Nd4jLong*") LongPointer resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") LongPointer dresultShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void concat(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int dimension,
        int numArrays,
        @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
        Pointer result, @Cast("Nd4jLong*") LongBuffer resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") LongBuffer dresultShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);
public native void concat(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int dimension,
        int numArrays,
        @Cast("Nd4jPointer*") PointerPointer data, @Cast("Nd4jPointer*") PointerPointer inputShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer ddata, @Cast("Nd4jPointer*") PointerPointer dinputShapeInfo,
        Pointer result, @Cast("Nd4jLong*") long[] resultShapeInfo,
        Pointer dresult, @Cast("Nd4jLong*") long[] dresultShapeInfo,
        @Cast("Nd4jPointer*") PointerPointer tadPointers, @Cast("Nd4jPointer*") PointerPointer offsetPointers);


public native void specialConcat(
        @Cast("Nd4jPointer*") PointerPointer extraPointers,
        int dimension,

@@ -22877,6 +22877,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

// #include <dll.h>
// #include <pointercast.h>
// #include <execution/ErrorReference.h>
@Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */

@@ -22912,6 +22913,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
    public native void setScalarBuffer(Pointer pointer);
    public native void setAllocationBuffer(Pointer pointer);

    public native ErrorReference errorReference();

    public native void triggerOwnership(@Cast("bool") boolean isOwner);

    public native int deviceId();

@@ -22961,6 +22964,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #include <vector>
// #include <mutex>
// #include <execution/ContextBuffers.h>
// #include <execution/ErrorReference.h>

@Namespace("nd4j") @NoOffset public static class LaunchContext extends Pointer {
    static { Loader.load(); }

@@ -22985,9 +22989,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

    public native int getDeviceID();
    public native void setDeviceID(int deviceID);
    public native ErrorReference errorReference();

    public static native @Cast("bool") boolean isInitialized();
    public static native void releaseBuffers();


    public static native LaunchContext defaultContext();
@@ -38,6 +38,7 @@ import java.util.Scanner;
                "array/ConstantDataBuffer.h",
                "array/ConstantDescriptor.h",
                "array/TadPack.h",
                "execution/ErrorReference.h",
                "Environment.h",
                "types/utf8string.h",
                "NativeOps.h",
@ -5216,6 +5216,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
|
||||
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
|
||||
|
||||
log.info("Array shapeInfo: {}", array.shapeInfoJava());
|
||||
|
||||
INDArray rev = Nd4j.reverse(array);
|
||||
|
||||
assertEquals(exp, rev);
|
||||
|
@@ -5226,7 +5228,7 @@ public class Nd4jTestsC extends BaseNd4jTest {

        INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
        INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});

        INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0];
        INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, array.ulike()))[0];

        assertEquals(exp, rev);
    }
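The switch from Nd4j.createUninitialized(array.length()) to array.ulike() in this hunk is a data-type fix as much as a readability one: ulike() allocates an uninitialized array with the same shape, data type and ordering as its source, whereas createUninitialized(long) typically yields the default floating-point type regardless of the input. A minimal sketch (not from the commit) of the difference:

    // Why ulike() is the safer output allocator for an op like Reverse.
    INDArray src = Nd4j.linspace(1, 10, 10, DataType.INT);
    INDArray a = Nd4j.createUninitialized(src.length());  // default float dtype, shape from length only
    INDArray b = src.ulike();                             // same shape, dtype and order as src
    // a.dataType() is the default float type; b.dataType() is INT - only b matches src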
@@ -5236,7 +5238,7 @@ public class Nd4jTestsC extends BaseNd4jTest {

        INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
        INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});

        INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array, Nd4j.createUninitialized(array.length())))[0];
        INDArray rev = Nd4j.getExecutioner().exec(new Reverse(array,array.ulike()))[0];

        assertEquals(exp, rev);
    }
@@ -5335,11 +5337,103 @@ public class Nd4jTestsC extends BaseNd4jTest {

        assertNotNull(lsd); //Fails here on CUDA, OK on native/cpu
    }

    @Test
    public void testReverseSmall_1() {
        val array = Nd4j.linspace(1, 10, 10, DataType.INT);
        val exp = array.dup(array.ordering());

        Transforms.reverse(array, false);
        Transforms.reverse(array, false);

        val jexp = exp.data().asInt();
        val jarr = array.data().asInt();
        assertArrayEquals(jexp, jarr);
        assertEquals(exp, array);
    }
    @Test
    public void testReverseSmall_2() {
        val array = Nd4j.linspace(1, 10, 10, DataType.INT);
        val exp = array.dup(array.ordering());

        val reversed = Transforms.reverse(array, true);
        val rereversed = Transforms.reverse(reversed, true);

        val jexp = exp.data().asInt();
        val jarr = rereversed.data().asInt();
        assertArrayEquals(jexp, jarr);
        assertEquals(exp, rereversed);
    }
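The second argument to Transforms.reverse controls copy-versus-in-place semantics, which is exactly what the two test variants above exercise. A minimal sketch (not part of this commit):

    // dup == false reverses the input in place; dup == true leaves the
    // input intact and returns a reversed copy.
    val a = Nd4j.linspace(1, 3, 3, DataType.INT);
    val b = Transforms.reverse(a, true);   // a stays [1, 2, 3], b is [3, 2, 1]
    Transforms.reverse(a, false);          // a becomes [3, 2, 1]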
    @Test
    public void testReverseSmall_3() {
        val array = Nd4j.linspace(1, 11, 11, DataType.INT);
        val exp = array.dup(array.ordering());

        Transforms.reverse(array, false);

        log.info("Reversed shapeInfo: {}", array.shapeInfoJava());
        log.info("Reversed: {}", array);

        Transforms.reverse(array, false);

        val jexp = exp.data().asInt();
        val jarr = array.data().asInt();
        assertArrayEquals(jexp, jarr);
        assertEquals(exp, array);
    }
    @Test
    public void testReverseSmall_4() {
        val array = Nd4j.linspace(1, 11, 11, DataType.INT);
        val exp = array.dup(array.ordering());

        val reversed = Transforms.reverse(array, true);

        log.info("Reversed: {}", reversed);

        val rereversed = Transforms.reverse(reversed, true);

        val jexp = exp.data().asInt();
        val jarr = rereversed.data().asInt();
        assertArrayEquals(jexp, jarr);
        assertEquals(exp, rereversed);
    }
    @Test
    public void testReverse_1() {
        val array = Nd4j.linspace(1, 2017152, 2017152, DataType.INT);
        val exp = array.dup(array.ordering());

        Transforms.reverse(array, false);
        Transforms.reverse(array, false);

        val jexp = exp.data().asInt();
        val jarr = array.data().asInt();
        assertArrayEquals(jexp, jarr);
        assertEquals(exp, array);
    }
    @Test
    public void testReverse_2() {
        val array = Nd4j.linspace(1, 2017152, 2017152, DataType.INT);
        val exp = array.dup(array.ordering());

        val reversed = Transforms.reverse(array, true);
        val rereversed = Transforms.reverse(reversed, true);

        val jexp = exp.data().asInt();
        val jarr = rereversed.data().asInt();
        assertArrayEquals(jexp, jarr);
        assertEquals(exp, rereversed);
    }
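These two large-array variants assert the same round-trip property as the small tests, at a length (2,017,152 elements) presumably chosen to exercise the device-side reverse at scale. The invariant they check can be stated in plain Java (illustration only, not from the commit):

    // Reference statement of the invariant: reversing twice is the identity.
    int n = 2017152;
    int[] a = new int[n];
    for (int i = 0; i < n; i++)
        a[i] = i + 1;                  // mirrors Nd4j.linspace(1, n, n, INT)
    int[] once = new int[n];
    for (int i = 0; i < n; i++)
        once[i] = a[n - 1 - i];        // first reverse
    int[] twice = new int[n];
    for (int i = 0; i < n; i++)
        twice[i] = once[n - 1 - i];    // second reverse restores the original
    assert java.util.Arrays.equals(a, twice);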
    @Test
    public void testNativeSort3_1() {
        INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1);
        INDArray exp = array.dup();
        Transforms.reverse(array, false);
        log.info("Reverse: {}", array);

        long time1 = System.currentTimeMillis();