From 9cf28ea6c9dbcb2333191c487ed159a9c743da10 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 15 Jul 2019 16:36:35 +0300 Subject: [PATCH] [WIP] CUDA tweaks (#60) * special cpu concat Signed-off-by: raver119 * special concat fix Signed-off-by: raver119 * OpProfiler tweak for absent host pointers Signed-off-by: raver119 * minor test tweak to see orders Signed-off-by: raver119 * CUDA broadcasting diff orders fix Signed-off-by: raver119 * faster iterations Signed-off-by: raver119 * OldSoftMax/OldLogSoftMax gone Signed-off-by: raver119 * RandomLauncher tweaks Signed-off-by: raver119 * additional check int randomtests Signed-off-by: raver119 * skip prepare/register action for empty arrays Signed-off-by: raver119 * npz float16 fix Signed-off-by: raver119 * empty reduction cuda fixes Signed-off-by: raver119 * ShapeBufferTests tweaks Signed-off-by: raver119 --- libnd4j/include/helpers/RandomLauncher.h | 21 ++-- .../include/helpers/impl/RandomLauncher.cpp | 61 ++++++---- libnd4j/include/loops/cuda/broadcasting.cu | 11 +- .../declarable/generic/random/bernoulli.cpp | 2 +- .../declarable/generic/random/exponential.cpp | 2 +- .../ops/declarable/generic/random/normal.cpp | 2 +- .../ops/declarable/generic/random/uniform.cpp | 2 +- .../ops/declarable/helpers/cpu/transforms.cpp | 66 +---------- .../ops/declarable/impl/LegacyRandomOp.cpp | 16 +-- libnd4j/include/ops/impl/specials.cpp | 86 ++++++++++++-- libnd4j/include/ops/specials.h | 3 + .../converters/ImportClassMapping.java | 2 - .../nd4j/linalg/activations/Activation.java | 106 +----------------- .../activations/impl/ActivationSoftmax.java | 7 +- .../impl/transforms/custom/LogSoftMax.java | 5 - .../ops/impl/transforms/custom/SoftMax.java | 16 ++- .../impl/transforms/strict/OldLogSoftMax.java | 87 -------------- .../impl/transforms/strict/OldSoftMax.java | 94 ---------------- .../transforms/strict/SoftMaxDerivative.java | 17 ++- .../java/org/nd4j/linalg/api/shape/Shape.java | 79 ++++++------- .../lossfunctions/impl/LossBinaryXENT.java | 4 +- .../impl/LossMixtureDensity.java | 5 +- .../linalg/ops/transforms/Transforms.java | 4 +- .../org/nd4j/linalg/profiler/OpProfiler.java | 4 +- .../nativeblas/BaseNativeNDArrayFactory.java | 1 + .../flow/impl/SynchronousFlowController.java | 53 ++++----- .../jcublas/buffer/BaseCudaDataBuffer.java | 5 + .../ops/executioner/CudaExecutioner.java | 21 +++- .../cpu/nativecpu/CpuNDArrayFactory.java | 7 -- .../opvalidation/ReductionOpValidation.java | 11 +- .../opvalidation/TransformOpValidation.java | 3 +- .../test/java/org/nd4j/linalg/LoneTest.java | 4 +- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 16 ++- .../java/org/nd4j/linalg/crash/CrashTest.java | 10 +- .../linalg/mixed/MixedDataTypesTests.java | 5 +- .../org/nd4j/linalg/ops/DerivativeTests.java | 28 ++--- .../nd4j/linalg/ops/OpExecutionerTests.java | 29 ++--- .../nd4j/linalg/ops/OpExecutionerTestsC.java | 33 +++--- .../profiling/OperationProfilerTests.java | 12 +- .../java/org/nd4j/linalg/rng/RandomTests.java | 6 +- .../nd4j/linalg/shape/ShapeBufferTests.java | 31 ++--- .../linalg/api/buffer/BaseDataBuffer.java | 5 + .../nd4j/linalg/api/buffer/DataBuffer.java | 8 ++ 43 files changed, 391 insertions(+), 599 deletions(-) delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/OldLogSoftMax.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/OldSoftMax.java diff --git a/libnd4j/include/helpers/RandomLauncher.h b/libnd4j/include/helpers/RandomLauncher.h index 6e9f13042..24921dc21 100644 --- a/libnd4j/include/helpers/RandomLauncher.h +++ b/libnd4j/include/helpers/RandomLauncher.h @@ -21,26 +21,27 @@ #include #include #include +#include namespace nd4j { class RandomLauncher { public: - static void applyDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); - static void applyInvertedDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); - static void applyAlphaDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z = nullptr); + static void applyDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); + static void applyInvertedDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr); + static void applyAlphaDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z = nullptr); - static void fillUniform(nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to); + static void fillUniform(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to); - static void fillGaussian(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); + static void fillGaussian(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); - static void fillExponential(nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda); + static void fillExponential(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda); - static void fillLogNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); + static void fillLogNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); - static void fillTruncatedNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); + static void fillTruncatedNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev); - static void fillBinomial(nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob); + static void fillBinomial(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob); - static void fillBernoulli(nd4j::graph::RandomGenerator& rng, NDArray* array, double prob); + static void fillBernoulli(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double prob); }; } \ No newline at end of file diff --git a/libnd4j/include/helpers/impl/RandomLauncher.cpp b/libnd4j/include/helpers/impl/RandomLauncher.cpp index 99294784d..a3fc86020 100644 --- a/libnd4j/include/helpers/impl/RandomLauncher.cpp +++ b/libnd4j/include/helpers/impl/RandomLauncher.cpp @@ -23,76 +23,97 @@ #include #include #include +#include namespace nd4j { // FIXME: implement this - void RandomLauncher::applyDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { + void RandomLauncher::applyDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { if (z == nullptr) z = array; ExtraArguments arguments({retainProb}); + PointersManager pm(context, "applyDropOut"); - NativeOpExecutioner::execRandom(nullptr, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); + NativeOpExecutioner::execRandom(context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); + pm.synchronize(); } - void RandomLauncher::applyInvertedDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { + void RandomLauncher::applyInvertedDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { if (z == nullptr) z = array; ExtraArguments arguments({retainProb}); + PointersManager pm(context, "applyInvertedDropOut"); - NativeOpExecutioner::execRandom(nullptr, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); + NativeOpExecutioner::execRandom(context, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); + pm.synchronize(); } - void RandomLauncher::applyAlphaDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) { + void RandomLauncher::applyAlphaDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) { if (z == nullptr) z = array; ExtraArguments arguments({retainProb, alpha, beta, alphaPrime}); + PointersManager pm(context, "applyAlphaDropOut"); - NativeOpExecutioner::execRandom(nullptr, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); + NativeOpExecutioner::execRandom(context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); + pm.synchronize(); } - void RandomLauncher::fillBernoulli(nd4j::graph::RandomGenerator& rng, NDArray* array, double prob) { + void RandomLauncher::fillBernoulli(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double prob) { ExtraArguments arguments({prob}); + PointersManager pm(context, "fillBernoulli"); - NativeOpExecutioner::execRandom(nullptr, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + NativeOpExecutioner::execRandom(context, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + pm.synchronize(); } - void RandomLauncher::fillUniform(nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to) { + void RandomLauncher::fillUniform(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to) { ExtraArguments arguments({from, to}); + PointersManager pm(context, "fillUniform"); - NativeOpExecutioner::execRandom(nullptr, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + NativeOpExecutioner::execRandom(context, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + pm.synchronize(); } - void RandomLauncher::fillGaussian(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { + void RandomLauncher::fillGaussian(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); + PointersManager pm(context, "fillGaussian"); - NativeOpExecutioner::execRandom(nullptr, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + pm.synchronize(); } - void RandomLauncher::fillExponential(nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda) { + void RandomLauncher::fillExponential(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda) { ExtraArguments arguments({lambda}); + PointersManager pm(context, "fillExponential"); - NativeOpExecutioner::execRandom(nullptr, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + NativeOpExecutioner::execRandom(context, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + pm.synchronize(); } - void RandomLauncher::fillLogNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { + void RandomLauncher::fillLogNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); + PointersManager pm(context, "fillLogNormal"); - NativeOpExecutioner::execRandom(nullptr, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + pm.synchronize(); } - void RandomLauncher::fillTruncatedNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { + void RandomLauncher::fillTruncatedNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); + PointersManager pm(context, "fillTruncatedNormal"); - NativeOpExecutioner::execRandom(nullptr, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + NativeOpExecutioner::execRandom(context, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + pm.synchronize(); } - void RandomLauncher::fillBinomial(nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) { + void RandomLauncher::fillBinomial(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) { ExtraArguments arguments({(double) trials, prob}); + PointersManager pm(context, "fillBinomial"); - NativeOpExecutioner::execRandom(nullptr, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + NativeOpExecutioner::execRandom(context, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + pm.synchronize(); } } diff --git a/libnd4j/include/loops/cuda/broadcasting.cu b/libnd4j/include/loops/cuda/broadcasting.cu index b61f4f019..b6a7b1830 100644 --- a/libnd4j/include/loops/cuda/broadcasting.cu +++ b/libnd4j/include/loops/cuda/broadcasting.cu @@ -128,6 +128,10 @@ namespace functions { } __syncthreads(); + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(tadOnlyShapeInfo); + auto zOrder = shape::order(tadOnlyShapeInfoZ); + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { @@ -135,7 +139,7 @@ namespace functions { auto rZ = z + tadOffsetsZ[r]; - if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { + if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1 && xOrder == yOrder && xOrder == zOrder) { for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); } @@ -190,6 +194,9 @@ namespace functions { } __syncthreads(); + auto xOrder = shape::order(tadOnlyShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(tadOnlyShapeInfoZ); for (int r = blockIdx.x; r < numTads; r += gridDim.x) { @@ -197,7 +204,7 @@ namespace functions { auto rZ = z + tadOffsetsZ[r]; - if(tadEWS > 0 && zEWS > 0 && yEWS > 0) { + if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && xOrder == yOrder && xOrder == zOrder) { for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); } diff --git a/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp b/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp index 06a95128c..338760921 100644 --- a/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp +++ b/libnd4j/include/ops/declarable/generic/random/bernoulli.cpp @@ -44,7 +44,7 @@ namespace nd4j { auto z = OUTPUT_VARIABLE(0); auto f = T_ARG(0); - RandomLauncher::fillBernoulli(rng, z, f); + RandomLauncher::fillBernoulli(block.launchContext(), rng, z, f); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/random/exponential.cpp b/libnd4j/include/ops/declarable/generic/random/exponential.cpp index 44a1014b7..bc942fc9b 100644 --- a/libnd4j/include/ops/declarable/generic/random/exponential.cpp +++ b/libnd4j/include/ops/declarable/generic/random/exponential.cpp @@ -53,7 +53,7 @@ namespace nd4j { auto z = OUTPUT_VARIABLE(0); auto lambda = T_ARG(0); - RandomLauncher::fillExponential(rng, z, lambda); + RandomLauncher::fillExponential(block.launchContext(), rng, z, lambda); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/random/normal.cpp b/libnd4j/include/ops/declarable/generic/random/normal.cpp index 64c3a0d14..781d495f0 100644 --- a/libnd4j/include/ops/declarable/generic/random/normal.cpp +++ b/libnd4j/include/ops/declarable/generic/random/normal.cpp @@ -39,7 +39,7 @@ namespace nd4j { functions::random::RandomFunction::template execTransform>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data()); */ - RandomLauncher::fillGaussian(rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1)); + RandomLauncher::fillGaussian(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1)); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index 32f1cb1b2..b678852fb 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -53,7 +53,7 @@ namespace nd4j { */ REQUIRE_TRUE(block.numT() > 1, 0, "RandomUniform: to/from must be set"); - RandomLauncher::fillUniform(rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1)); + RandomLauncher::fillUniform(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1)); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp index ef0ffec5f..5a132a5c6 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp @@ -1203,71 +1203,7 @@ static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& o ////////////////////////////////////////////////////////////////////////// template static void concat_(const std::vector& inArrs, NDArray& output, const int axis) { - - const uint numOfArrs = inArrs.size(); - - int outDim; - const bool isOutputVector = output.isCommonVector(outDim); - - if(isOutputVector || (axis == 0 && output.ordering() == 'c')) { - - bool allVectorsOrScalars = true; - const uint outEws = isOutputVector ? output.stridesOf()[outDim] : output.ews(); - - std::vector nonUnityDim(numOfArrs); - std::vector zOffset(numOfArrs); - - for(int i = 0; i < numOfArrs; i++) { - allVectorsOrScalars &= (inArrs[i]->lengthOf() == 1 || inArrs[i]->isCommonVector(nonUnityDim[i])); - if(!allVectorsOrScalars) - break; - if(i == 0) zOffset[0] = 0; - else zOffset[i] = zOffset[i - 1] + outEws * inArrs[i - 1]->lengthOf(); - } - - if(allVectorsOrScalars) { - - T* outBuff = output.bufferAsT(); - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (uint r = 0; r < numOfArrs; r++) { - - const uint arrLen = inArrs[r]->lengthOf(); - const uint xEws = (arrLen == 1) ? 1 : inArrs[r]->stridesOf()[nonUnityDim[r]]; - - T *z = outBuff + zOffset[r]; - T *x = inArrs[r]->bufferAsT(); - - if(outEws == 1 && xEws == 1) - for (uint e = 0; e < arrLen; e++) - z[e] = x[e]; - else - for (uint e = 0; e < arrLen; e++) - z[e * outEws] = x[e * xEws]; - } - return; - } - } - - const int rank = inArrs[0]->rankOf(); - const int rank2 = 2*rank; - std::vector> indices(numOfArrs, std::vector(rank2,0)); - - // take into account indices for first array - indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis); - - // loop through the rest of input arrays - for(int i = 1; i < numOfArrs; ++i) { - indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from - indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding) - } - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for(int i = 0; i < numOfArrs; ++i) { - auto temp = output(indices[i], true); - nd4j::TransformLoops::template loopTransform, false>(inArrs[i]->bufferAsT(), inArrs[i]->getShapeInfo(), temp.bufferAsT(), temp.getShapeInfo(), nullptr); - // temp.assign(inArrs[i]); - } + nd4j::SpecialMethods::concatCpuGeneric(inArrs, output, axis); } void concat(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index c3e7dcec5..1ce00f44a 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -81,7 +81,7 @@ namespace nd4j { auto z = OUTPUT_VARIABLE(0); //NDArrayFactory::create_('c', shape, block.getWorkspace()); - RandomLauncher::fillUniform(block.randomGenerator(), z, from, to); + RandomLauncher::fillUniform(block.launchContext(), block.randomGenerator(), z, from, to); // FIXME: //OVERWRITE_RESULT(z); @@ -105,7 +105,7 @@ namespace nd4j { if (!block.isInplace()) z->assign(input); - RandomLauncher::applyDropOut(block.randomGenerator(), z, prob); + RandomLauncher::applyDropOut(block.launchContext(), block.randomGenerator(), z, prob); } break; case nd4j::random::DropOutInverted: { @@ -140,7 +140,7 @@ namespace nd4j { auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.getWorkspace()); - RandomLauncher::fillGaussian(block.randomGenerator(), z, mean, stdev); + RandomLauncher::fillGaussian(block.launchContext(), block.randomGenerator(), z, mean, stdev); // FIXME: !! //OVERWRITE_RESULT(z); @@ -168,7 +168,7 @@ namespace nd4j { auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.getWorkspace()); - RandomLauncher::fillBernoulli(block.randomGenerator(), z, prob); + RandomLauncher::fillBernoulli(block.launchContext(), block.randomGenerator(), z, prob); // FIXME: //OVERWRITE_RESULT(z); @@ -201,7 +201,7 @@ namespace nd4j { auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.getWorkspace()); - RandomLauncher::fillBinomial(block.randomGenerator(), z, trials, prob); + RandomLauncher::fillBinomial(block.launchContext(), block.randomGenerator(), z, trials, prob); // FIXME: !!! //OVERWRITE_RESULT(z); @@ -233,7 +233,7 @@ namespace nd4j { auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_('c', shape, block.getWorkspace()); - RandomLauncher::fillLogNormal(block.randomGenerator(), z, mean, stdev); + RandomLauncher::fillLogNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev); // FIXME: !! //OVERWRITE_RESULT(z); @@ -265,7 +265,7 @@ namespace nd4j { auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_('c', shape, block.getWorkspace()); - RandomLauncher::fillTruncatedNormal(block.randomGenerator(), z, mean, stdev); + RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev); // FIXME: !!! //OVERWRITE_RESULT(z); @@ -301,7 +301,7 @@ namespace nd4j { if (!block.isInplace()) z->assign(input); - RandomLauncher::applyAlphaDropOut(block.randomGenerator(), z, prob, a, b, pa); + RandomLauncher::applyAlphaDropOut(block.launchContext(), block.randomGenerator(), z, prob, a, b, pa); } break; case nd4j::random::Linspace: { diff --git a/libnd4j/include/ops/impl/specials.cpp b/libnd4j/include/ops/impl/specials.cpp index 3a07ba1e2..8c7c04f60 100644 --- a/libnd4j/include/ops/impl/specials.cpp +++ b/libnd4j/include/ops/impl/specials.cpp @@ -28,9 +28,81 @@ #include #include #include +#include namespace nd4j { +/** +* Concatneate multi array of the same shape together +* along a particular dimension +*/ +template +void SpecialMethods::concatCpuGeneric(const std::vector& inArrs, NDArray& output, const int axis) { + const uint numOfArrs = inArrs.size(); + + int outDim; + const bool isOutputVector = output.isCommonVector(outDim); + + if(isOutputVector || (axis == 0 && output.ordering() == 'c')) { + + bool allVectorsOrScalars = true; + const uint outEws = isOutputVector ? output.stridesOf()[outDim] : output.ews(); + + std::vector nonUnityDim(numOfArrs); + std::vector zOffset(numOfArrs); + + for(int i = 0; i < numOfArrs; i++) { + allVectorsOrScalars &= (inArrs[i]->lengthOf() == 1 || inArrs[i]->isCommonVector(nonUnityDim[i])); + if(!allVectorsOrScalars) + break; + if(i == 0) zOffset[0] = 0; + else zOffset[i] = zOffset[i - 1] + outEws * inArrs[i - 1]->lengthOf(); + } + + if(allVectorsOrScalars) { + + T* outBuff = output.bufferAsT(); + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (uint r = 0; r < numOfArrs; r++) { + + const uint arrLen = inArrs[r]->lengthOf(); + const uint xEws = (arrLen == 1) ? 1 : inArrs[r]->stridesOf()[nonUnityDim[r]]; + + T *z = outBuff + zOffset[r]; + T *x = inArrs[r]->bufferAsT(); + + if(outEws == 1 && xEws == 1) + for (uint e = 0; e < arrLen; e++) + z[e] = x[e]; + else + for (uint e = 0; e < arrLen; e++) + z[e * outEws] = x[e * xEws]; + } + return; + } + } + + const int rank = inArrs[0]->rankOf(); + const int rank2 = 2*rank; + std::vector> indices(numOfArrs, std::vector(rank2,0)); + + // take into account indices for first array + indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis); + + // loop through the rest of input arrays + for(int i = 1; i < numOfArrs; ++i) { + indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from + indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding) + } + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for(int i = 0; i < numOfArrs; ++i) { + auto temp = output(indices[i], true); + nd4j::TransformLoops::template loopTransform, false>(inArrs[i]->bufferAsT(), inArrs[i]->getShapeInfo(), temp.bufferAsT(), temp.getShapeInfo(), nullptr); + } +} + /** * Concatneate multi array of the same shape together * along a particular dimension @@ -38,24 +110,14 @@ namespace nd4j { template void SpecialMethods::concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *vresult, Nd4jLong *resultShapeInfo) { auto result = reinterpret_cast(vresult); - - std::vector iArgs = {dimension}; - std::vector tArgs; - std::vector bArgsEmpty; std::vector inputs(numArrays); - std::vector outputs(1); - outputs[0] = new NDArray(static_cast(result), static_cast(resultShapeInfo)); + NDArray output(static_cast(result), static_cast(resultShapeInfo)); for(int i = 0; i < numArrays; ++i) inputs[i] = new NDArray(static_cast(data[i]), static_cast(inputShapeInfo[i])); - nd4j::ops::concat op; - auto status = op.execute(inputs, outputs, tArgs, iArgs, bArgsEmpty); - if(status != Status::OK()) - throw std::runtime_error("concatCpuGeneric fails to be executed !"); - - delete outputs[0]; + nd4j::SpecialMethods::concatCpuGeneric(inputs, output, dimension); for(int i = 0; i < numArrays; ++i) delete inputs[i]; diff --git a/libnd4j/include/ops/specials.h b/libnd4j/include/ops/specials.h index 4d2c384fa..6919aa38d 100644 --- a/libnd4j/include/ops/specials.h +++ b/libnd4j/include/ops/specials.h @@ -30,6 +30,8 @@ #include namespace nd4j { + class NDArray; + //FIXME: get rid of this redefinition typedef union { @@ -47,6 +49,7 @@ namespace nd4j { template class ND4J_EXPORT SpecialMethods { public: + static void concatCpuGeneric(const std::vector& inArrs, NDArray& output, const int axis); static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *result, Nd4jLong *resultShapeInfo); static void accumulateGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length); static void averageGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length, bool propagate); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index f1921cfc7..5f45a4a8a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -541,8 +541,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.strict.Log.class, org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p.class, org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid.class, - org.nd4j.linalg.api.ops.impl.transforms.strict.OldLogSoftMax.class, - org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax.class, org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU.class, org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java index 53c1531ea..145baf6e0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity; import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; import org.nd4j.linalg.api.ops.impl.scalar.Step; @@ -161,109 +162,4 @@ public enum Activation { throw new UnsupportedOperationException("Activation function not yet supported: " + this); } } - - /** - * Get the Activation function as an ND4J Transform, applied on either the input or a copy of the input - * - * @param in Input to apply the activation function op to - * @param dup If true: duplicate the array before applying the transform. If false: don't duplicate - * @return The transform op (execute using {@code Nd4j.getExecutioner().exec(op)} - */ - public Op asTransform(INDArray in, boolean dup) { - if (dup) { - in = in.dup(); - } - switch (this) { - case CUBE: - return new Cube(in); - case ELU: - return new ELU(in); - case HARDSIGMOID: - return new HardSigmoid(in); - case HARDTANH: - return new HardTanh(in); - case IDENTITY: - return new OldIdentity(in); - case LEAKYRELU: - return new LeakyReLU(in); - case RATIONALTANH: - return new RationalTanh(in); - case RELU: - return new RectifiedLinear(in); - case SIGMOID: - return new Sigmoid(in); - case SOFTMAX: - return new OldSoftMax(in); - case SOFTPLUS: - return new SoftPlus(in); - case SOFTSIGN: - return new SoftSign(in); - case TANH: - return new Tanh(in); - case RECTIFIEDTANH: - return new RectifiedTanh(in); - case SELU: - return new SELU(in); - case SWISH: - return new Swish(in); - case GELU: - return new GELU(in); - case RRELU: - default: - throw new UnsupportedOperationException("Not supported via this method: " + this); - } - } - - /** - * Get the Activation function derivative (i.e., dOut/dIn) as an ND4J Transform, applied on either the input - * or a copy of the input - * - * @param in Input to apply the activation function derivative op to - * @param dup If true: duplicate the array before applying the transform. If false: don't duplicate - * @return The op (execute using {@code Nd4j.getExecutioner().exec(op)} - */ - public Op asTransformDerivative(INDArray in, boolean dup) { - if (dup) { - in = in.dup(); - } - switch (this) { - case CUBE: - return new CubeDerivative(in); - case ELU: - return new ELUDerivative(in); - case HARDSIGMOID: - return new HardSigmoidDerivative(in); - case HARDTANH: - return new HardTanhDerivative(in); - case LEAKYRELU: - return new LeakyReLUDerivative(in); - case RATIONALTANH: - return new RationalTanhDerivative(in); - case SIGMOID: - return new SigmoidDerivative(in); - case SOFTPLUS: - return new Sigmoid(in); - case SOFTSIGN: - return new SoftSignDerivative(in); - case TANH: - return new TanhDerivative(in); - case RECTIFIEDTANH: - return new RectifiedTanhDerivative(in); - case SELU: - return new SELUDerivative(in); - case SWISH: - return new SwishDerivative(in); - case SOFTMAX: - return new SoftMaxDerivative(in); - case IDENTITY: - return new ScalarSet(in, 1.0); - case RELU: - return new Step(in); - case GELU: - return new GELUDerivative(in); - case RRELU: - default: - throw new UnsupportedOperationException("Not supported via this method: " + this); - } - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java index 912a7ed18..2fd8b439a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java @@ -20,7 +20,8 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; @@ -34,14 +35,14 @@ public class ActivationSoftmax extends BaseActivationFunction { @Override public INDArray getActivation(INDArray in, boolean training) { - Nd4j.getExecutioner().execAndReturn(new OldSoftMax(in)); + Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(in, in)); return in; } @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray out = Nd4j.getExecutioner().exec(new OldSoftMax(in)); + INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0]; INDArray x = out.mul(epsilon).sum(1); INDArray dLdz = out.mul(epsilon.subColumnVector(x)); return new Pair<>(dLdz, null); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java index b13870606..09a4823e0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java @@ -21,14 +21,9 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.ops.transforms.Transforms; -import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java index 47a70aff9..bfa1c27c1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java @@ -19,10 +19,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import java.nio.Buffer; import java.util.Collections; import java.util.List; @@ -53,10 +57,6 @@ public class SoftMax extends BaseDynamicTransformOp { super(sameDiff, args, inPlace); } - public SoftMax(INDArray input, INDArray result){ - super(new INDArray[]{input}, new INDArray[]{result}); - } - public SoftMax(SameDiff sameDiff, SDVariable[] args, int dimension) { super(sameDiff, args, false); this.dimension = dimension; @@ -75,13 +75,19 @@ public class SoftMax extends BaseDynamicTransformOp { addIArgument(dimension); } + public SoftMax(INDArray input){ + this(input, input); + } + + public SoftMax(INDArray input, INDArray result){ + this(input, result, -1); + } @Override public String opName() { return "softmax"; } - @Override public String onnxName() { return "Softmax"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/OldLogSoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/OldLogSoftMax.java deleted file mode 100644 index 4b1c0c378..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/OldLogSoftMax.java +++ /dev/null @@ -1,87 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.strict; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformStrictOp; - -import java.util.Collections; -import java.util.List; - -/** - * Old LogSoftMax function - * - * @author Adam Gibson - */ - -public class OldLogSoftMax extends BaseTransformStrictOp { - public OldLogSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldLogSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldLogSoftMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldLogSoftMax() { - } - - public OldLogSoftMax(INDArray x){ - this(x,x); - } - - public OldLogSoftMax(INDArray x, INDArray z) { - super(x, z); - Preconditions.checkArgument(x != null && x.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got x (source) array with shape: %ndShape", x); - Preconditions.checkArgument(z != null && z.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got z (result) array with shape: %ndShape", z); - } - - @Override - public int opNum() { - return 2; - } - - @Override - public String opName() { - return "old_logsoftmax"; - } - - - @Override - public String onnxName() { - return "old_LogSoftmax"; - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List i_v) { - SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0)); - return Collections.singletonList(ret); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/OldSoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/OldSoftMax.java deleted file mode 100644 index 15b54f9a2..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/OldSoftMax.java +++ /dev/null @@ -1,94 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.strict; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformStrictOp; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.Collections; -import java.util.List; - -/** - * Soft max function - * row_maxes is a row vector (max for each row) - * row_maxes = rowmaxes(input) - * diff = exp(input - max) / diff.rowSums() - * Outputs a probability distribution. - * Note that this is a parameterized model and requires - * the sum and max for the vector being calculated - * - * @author Adam Gibson - */ - -public class OldSoftMax extends BaseTransformStrictOp { - public OldSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); - } - - public OldSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); - } - - public OldSoftMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public OldSoftMax() { - } - - public OldSoftMax(INDArray x){ - this(x,x); - } - - public OldSoftMax(INDArray x, INDArray z) { - super(x, z); - Preconditions.checkArgument(x != null && x.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got x (source) array with shape: %ndShape", x); - Preconditions.checkArgument(z != null && z.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got z (result) array with shape: %ndShape", z); - } - - @Override - public int opNum() { - return 0; - } - - @Override - public String opName() { - return "old_softmax"; - } - - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public List doDiff(List i_v) { - SDVariable ret = f().softmaxDerivative(arg(), i_v.get(0), 1); - return Collections.singletonList(ret); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftMaxDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftMaxDerivative.java index 0e4cd02c6..b9bf8d2fe 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftMaxDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftMaxDerivative.java @@ -19,20 +19,27 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; + +import java.nio.Buffer; /** * Softmax derivative * * @author Adam Gibson */ -public class SoftMaxDerivative extends OldSoftMax { +public class SoftMaxDerivative extends SoftMax { public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, i_v1, i_v2); + super(sameDiff, new SDVariable[]{i_v1, i_v2}); } public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, i_v1, i_v2, inPlace); + super(sameDiff, new SDVariable[]{ i_v1, i_v2}, inPlace); } public SoftMaxDerivative(INDArray x, INDArray z) { @@ -40,11 +47,13 @@ public class SoftMaxDerivative extends OldSoftMax { } public SoftMaxDerivative(INDArray x) { - super(x); + super(x, x); } public SoftMaxDerivative() {} + + @Override public int opNum() { return 1; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index 89e8cbec4..a35fe8f43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -1334,6 +1334,17 @@ public class Shape { } } + public static boolean isVector(LongBuffer shapeInfo) { + int rank = Shape.rank(shapeInfo); + if (rank > 2 || rank < 1) + return false; + else { + long len = Shape.length(shapeInfo); + val shape = Shape.shapeOf(shapeInfo); + return shape.get(0) == len || shape.get(1) == len; + } + } + /** * Returns whether the given shape is a vector * @@ -2498,6 +2509,14 @@ public class Shape { return ret; } + public static int length(LongBuffer buffer) { + int ret = 1; + val shape = Shape.shapeOf(buffer); + int rank = Shape.rank(buffer); + for (int i = 0; i < rank; i++) + ret *= shape.get(i); + return ret; + } /** * Gets the rank given the shape info buffer @@ -2764,8 +2783,8 @@ public class Shape { * @param rank the rank to get the length for * @return rank * 2 + 4 */ - public static int shapeInfoLength(int rank) { - return rank * 2 + 4; + public static int shapeInfoLength(long rank) { + return (int) rank * 2 + 4; } public static int shapeInfoLength(long[] shape) { @@ -3072,6 +3091,11 @@ public class Shape { return buffer.get(length2 - 2); } + public static long elementWiseStride(LongBuffer buffer) { + int length2 = shapeInfoLength(buffer.get(0)); + return buffer.get(length2 - 2); + } + /** * Get the element wise stride for the * shape info buffer @@ -3179,40 +3203,6 @@ public class Shape { throw new RuntimeException("setOrder called"); } - /** - * Creates the shape information buffer - * given the shape,stride - * @param shape the shape for the buffer - * @param stride the stride for the buffer - * @param offset the offset for the buffer - * @param elementWiseStride the element wise stride for the buffer - * @param order the order for the buffer - * @return the shape information buffer given the parameters - */ - public static DataBuffer createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) { - if (shape.length != stride.length) - throw new IllegalStateException("Shape and stride must be the same length"); - - int rank = shape.length; - int shapeBuffer[] = new int[rank * 2 + 4]; - shapeBuffer[0] = rank; - int count = 1; - for (int e = 0; e < shape.length; e++) - shapeBuffer[count++] = shape[e]; - - for (int e = 0; e < stride.length; e++) - shapeBuffer[count++] = stride[e]; - - shapeBuffer[count++] = (int) offset; - shapeBuffer[count++] = elementWiseStride; - shapeBuffer[count] = (int) order; - - DataBuffer ret = Nd4j.createBufferDetached(shapeBuffer); - ret.setConstant(true); - - return ret; - } - public static DataBuffer createShapeInformation(long[] shape, long[] stride, long elementWiseStride, char order, DataType dataType, boolean empty) { boolean isEmpty = empty; if (!empty) @@ -3438,9 +3428,20 @@ public class Shape { public static boolean contentEquals(long[] arr, IntBuffer other) { for (int i = 0; i < arr.length; i++) { - Buffer buffer2 = (Buffer) other; - buffer2.position(i); - if (arr[i] != other.get()) { + val t = arr[i]; + val o = other.get(i); + if (t != o) { + return false; + } + } + return true; + } + + public static boolean contentEquals(long[] arr, LongBuffer other) { + for (int i = 0; i < arr.length; i++) { + val t = arr[i]; + val o = other.get(i); + if (t != o) { return false; } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java index 379be151f..604585f6c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java @@ -25,8 +25,8 @@ import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossUtil; @@ -121,7 +121,7 @@ public class LossBinaryXENT implements ILossFunction { INDArray scoreArr; if (activationFn instanceof ActivationSoftmax) { //TODO Post GPU support for custom ops: Use LogSoftMax op to avoid numerical issues when calculating score - INDArray logsoftmax = Nd4j.getExecutioner().exec(new OldSoftMax(preOutput.dup())); + INDArray logsoftmax = Nd4j.exec((CustomOp) new SoftMax(preOutput, preOutput.ulike(), -1))[0]; Transforms.log(logsoftmax, false); scoreArr = logsoftmax.muli(labels); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java index af96748fd..53f87265a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java @@ -20,7 +20,8 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -139,7 +140,7 @@ public class LossMixtureDensity implements ILossFunction { // Alpha is a softmax because // the alpha should all sum to 1 for a given gaussian mixture. - mdc.alpha = Nd4j.getExecutioner().exec(new OldSoftMax(mdc.alpha)); + mdc.alpha = Nd4j.exec((CustomOp) new SoftMax(mdc.alpha, mdc.alpha, -1))[0]; // Mu comes directly from the network as an unmolested value. // Note that this effectively means that the output layer of diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java index 43df4a430..e676197ee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java @@ -21,6 +21,7 @@ import lombok.val; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.ScalarOp; import org.nd4j.linalg.api.ops.TransformOp; import org.nd4j.linalg.api.ops.impl.reduce3.*; @@ -29,6 +30,7 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNot; import org.nd4j.linalg.api.ops.impl.shape.Cross; import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.floating.*; import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative; @@ -512,7 +514,7 @@ public class Transforms { * @return */ public static INDArray softmax(INDArray in, boolean copy) { - return Nd4j.getExecutioner().exec(new OldSoftMax(in, (copy ? in.ulike() : in))); + return Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, (copy ? in.ulike() : in), -1))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java index 6519f89b4..02d71b54d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java @@ -240,7 +240,7 @@ public class OpProfiler { String opClass = getOpClass(op); classCounter.incrementCount(opClass); - if(op.x() == null || (op.x() != null && op.x().data().address() == lastZ && op.z() == op.x() && op.y() == null)) { + if(op.x() == null || (op.x() != null && op.x().data().platformAddress() == lastZ && op.z() == op.x() && op.y() == null)) { // we have possible shift here matchingCounter.incrementCount(prevOpMatching + " -> " + opClass); matchingCounterDetailed.incrementCount(prevOpMatchingDetailed + " -> " + opClass + " " + op.opName()); @@ -254,7 +254,7 @@ public class OpProfiler { } } - lastZ = op.z() != null ? op.z().data().address() : 0L; + lastZ = op.z() != null ? op.z().data().platformAddress() : 0L; prevOpMatching = opClass; prevOpMatchingDetailed = opClass + " " + op.opName(); prevOpMatchingInverted = opClass + " " + op.opName(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java index 3b7ac83c0..e7bd7404c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java @@ -610,6 +610,7 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { bb2.put((byte)((s >> 8) & 0xff)); bb2.put((byte)(s & 0xff)); } + Nd4j.getAffinityManager().tagLocation(arr, AffinityManager.Location.HOST); map.put(fName, arr.reshape(order, shape)); } else if(dt == DataType.LONG){ long[] d = new long[(int)size]; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java index 2f84fa54d..9a8feeb0b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java @@ -72,7 +72,7 @@ public class SynchronousFlowController implements FlowController { public void synchronizeToHost(AllocationPoint point) { if (!point.isActualOnHostSide()) { - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = (CudaContext) allocator.getDeviceContext().getContext(); if (!point.isConstant()) waitTillFinished(point); @@ -102,7 +102,7 @@ public class SynchronousFlowController implements FlowController { if (!point.isActualOnDeviceSide()) { if (point.getAllocationStatus() == AllocationStatus.DEVICE) { - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = (CudaContext) allocator.getDeviceContext().getContext(); long perfD = PerformanceTracker.getInstance().helperStartTransaction(); @@ -135,17 +135,17 @@ public class SynchronousFlowController implements FlowController { @Override public CudaContext prepareActionAllWrite(INDArray... operands) { - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); - int cId = allocator.getDeviceId(); + val context = (CudaContext) allocator.getDeviceContext().getContext(); + val cId = allocator.getDeviceId(); for (INDArray operand : operands) { - if (operand == null) + if (operand == null || operand.isEmpty()) continue; Nd4j.getCompressor().autoDecompress(operand); - AllocationPoint pointData = allocator.getAllocationPoint(operand); - AllocationPoint pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); + val pointData = allocator.getAllocationPoint(operand); + val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); pointData.acquireLock(); @@ -168,15 +168,15 @@ public class SynchronousFlowController implements FlowController { @Override public CudaContext prepareAction(INDArray result, INDArray... operands) { - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); - int cId = allocator.getDeviceId(); + val context = (CudaContext) allocator.getDeviceContext().getContext(); + val cId = allocator.getDeviceId(); - if (result != null) { + if (result != null && !result.isEmpty()) { Nd4j.getCompressor().autoDecompress(result); prepareDelayedMemory(result); - AllocationPoint pointData = allocator.getAllocationPoint(result); - AllocationPoint pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer()); + val pointData = allocator.getAllocationPoint(result); + val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer()); pointData.acquireLock(); @@ -196,13 +196,13 @@ public class SynchronousFlowController implements FlowController { } for (INDArray operand : operands) { - if (operand == null) + if (operand == null || operand.isEmpty()) continue; Nd4j.getCompressor().autoDecompress(operand); - AllocationPoint pointData = allocator.getAllocationPoint(operand); - AllocationPoint pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); + val pointData = allocator.getAllocationPoint(operand); + val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer()); pointData.acquireLock(); @@ -256,7 +256,7 @@ public class SynchronousFlowController implements FlowController { if (operand == null) continue; - AllocationPoint pointOperand = allocator.getAllocationPoint(operand); + val pointOperand = allocator.getAllocationPoint(operand); pointOperand.tickDeviceWrite(); eventsProvider.storeEvent(pointOperand.getLastWriteEvent()); pointOperand.setLastWriteEvent(eventsProvider.getEvent()); @@ -266,9 +266,10 @@ public class SynchronousFlowController implements FlowController { } public void registerAction(CudaContext context, INDArray result, INDArray... operands) { - if (result == null) + if (result == null || result.isEmpty()) return; - AllocationPoint point = allocator.getAllocationPoint(result); + + val point = allocator.getAllocationPoint(result); point.tickDeviceWrite(); eventsProvider.storeEvent(point.getLastWriteEvent()); point.setLastWriteEvent(eventsProvider.getEvent()); @@ -276,10 +277,10 @@ public class SynchronousFlowController implements FlowController { point.releaseLock(); for (INDArray operand : operands) { - if (operand == null) + if (operand == null || operand.isEmpty()) continue; - AllocationPoint pointOperand = allocator.getAllocationPoint(operand); + val pointOperand = allocator.getAllocationPoint(operand); pointOperand.releaseLock(); eventsProvider.storeEvent(pointOperand.getLastReadEvent()); pointOperand.setLastReadEvent(eventsProvider.getEvent()); @@ -289,7 +290,7 @@ public class SynchronousFlowController implements FlowController { @Override public CudaContext prepareAction(AllocationPoint result, AllocationPoint... operands) { - CudaContext context = (CudaContext) allocator.getDeviceContext().getContext(); + val context = (CudaContext) allocator.getDeviceContext().getContext(); if (result != null) { result.acquireLock(); @@ -299,6 +300,7 @@ public class SynchronousFlowController implements FlowController { for (AllocationPoint operand : operands) { if (operand == null) continue; + operand.acquireLock(); operand.setCurrentContext(context); } @@ -313,15 +315,16 @@ public class SynchronousFlowController implements FlowController { protected void prepareDelayedMemory(INDArray array) { if (configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) { - AllocationPoint pointData = allocator.getAllocationPoint(array.shapeInfoDataBuffer()); - AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer()); + val pointData = allocator.getAllocationPoint(array.shapeInfoDataBuffer()); + val pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer()); if (pointData.getAllocationStatus() != AllocationStatus.DEVICE) prepareDelayedMemory(array.data()); if (pointShape.getAllocationStatus() == AllocationStatus.HOST) { - DataBuffer oShape = array.shapeInfoDataBuffer(); - DataBuffer nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape); + val oShape = array.shapeInfoDataBuffer(); + val nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape); + if (nShape == oShape) Nd4j.getConstantHandler().moveToConstantSpace(nShape); ((JCublasNDArray) array).setShapeInfoDataBuffer(nShape); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 79196d15b..2c5b7afa8 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -567,6 +567,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda return allocationPoint.getPointers().getHostPointer().address(); } + @Override + public long platformAddress() { + return allocationPoint.getPointers().getDevicePointer().address(); + } + @Override public Pointer pointer() { // FIXME: very bad thing, diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 90bc487df..723c7d8d7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -26,6 +26,7 @@ import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.LongIndexer; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; +import org.nd4j.base.Preconditions; import org.nd4j.jita.allocator.impl.AllocationPoint; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.CudaPointer; @@ -515,7 +516,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { } // in case of regular accumulation we don't care about array state before op - ret = Nd4j.createUninitialized(dtype, retShape); + ret = Nd4j.create(dtype, retShape); } op.setZ(ret); } else { @@ -536,11 +537,16 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public INDArray exec(IndexAccumulation op) { val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector()); - if (op.z() == null) { - long[] retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims()); - INDArray ret = Nd4j.createUninitialized(DataType.LONG, retShape); - op.setZ(ret); + if (op.x().isEmpty()) { + for (val d:dimension) { + Preconditions.checkArgument(op.x().shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape"); + } + } + + if (op.z() == null) { + val retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims()); + op.setZ(Nd4j.createUninitialized(DataType.LONG, retShape)); } long st = profilingConfigurableHookIn(op); @@ -556,10 +562,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { return op.x(); } + if (op.z().isEmpty()) + return op.z(); + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + val context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index b08c0ccf8..5f6fd2e5f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -619,9 +619,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { "Illegal concatenation at array " + i + " and shape element " + j); } } - - - //log.info("Shape[{}]: {}", i, Arrays.toString(toConcat[i].shapeInfoDataBuffer().asInt())); } if (allScalars) { @@ -630,8 +627,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { outputShape[dimension] = sumAlongDim; } - //PointerPointer dummy = new PointerPointer(new Pointer[] {null}); - INDArray ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order()); nativeOps.concat(null, dimension, toConcat.length, @@ -639,11 +634,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { null, null, ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(), null, null, - //new PointerPointer(new Pointer[] {null}), new PointerPointer(new Pointer[] {null})); null, null); return ret; - // return super.concat(dimension,toConcat); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index b36fba903..cc67f00d8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -31,6 +31,7 @@ import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin; @@ -1083,7 +1084,7 @@ public class ReductionOpValidation extends BaseOpValidation { final INDArray exec = Nd4j.matmul(keys, query, true, false, false) .divi(Math.sqrt(keys.size(1))); - Nd4j.exec(new SoftMax(exec, exec, 1)); + Nd4j.exec((CustomOp) new SoftMax(exec, exec, 1)); final INDArray finalOut = Nd4j.matmul(values, exec).norm1(); SameDiff sd = SameDiff.create(); @@ -1111,7 +1112,7 @@ public class ReductionOpValidation extends BaseOpValidation { final INDArray exec = Nd4j.matmul(keys, query, true, false, false) .divi(Math.sqrt(keys.size(1))); exec.addi(mask.reshape(10, 3, 1).sub(1).muli(1e9)); - Nd4j.exec(new SoftMax(exec, exec, 1)); + Nd4j.exec((CustomOp) new SoftMax(exec, exec, 1)); final INDArray finalOut = Nd4j.matmul(values, exec).norm1(); SameDiff sd = SameDiff.create(); @@ -1141,7 +1142,7 @@ public class ReductionOpValidation extends BaseOpValidation { final INDArray exec = Nd4j.matmul(keys, query, true, false, false) .divi(Math.sqrt(keys.size(-2))); exec.addi(Nd4j.tile(mask.reshape(2, 1, 3, 1), 1, 5, 1, 2).sub(1).muli(1e9)); - Nd4j.exec(new SoftMax(exec, exec, -2)); + Nd4j.exec((CustomOp) new SoftMax(exec, exec, -2)); final INDArray finalOut = Nd4j.matmul(values, exec).norm1(); SameDiff sd = SameDiff.create(); @@ -1169,7 +1170,7 @@ public class ReductionOpValidation extends BaseOpValidation { final INDArray exec = Nd4j.matmul(keys, query, true, false, false) .divi(Math.sqrt(keys.size(-2))); - Nd4j.exec(new SoftMax(exec, exec, -2)); + Nd4j.exec((CustomOp) new SoftMax(exec, exec, -2)); final INDArray finalOut = Nd4j.matmul(values, exec).norm1(); SameDiff sd = SameDiff.create(); @@ -1249,7 +1250,7 @@ public class ReductionOpValidation extends BaseOpValidation { final INDArray exec = Nd4j.matmul(keys, query, true, false, false) .divi(Math.sqrt(keys.size(1))); exec.addi(mask.reshape(10, 3, 1).sub(1).muli(1e9)); - Nd4j.exec(new SoftMax(exec, exec, 1)); + Nd4j.exec((CustomOp) new SoftMax(exec, exec, 1)); final INDArray finalOut = Nd4j.matmul(values, exec).norm1(); for (char queryOrder : new char[]{'f', 'c'}) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index e4f3bd9ac..53a040670 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -42,6 +42,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax; import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin; import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; import org.nd4j.linalg.api.ops.impl.transforms.strict.*; @@ -671,7 +672,7 @@ public class TransformOpValidation extends BaseOpValidation { //TODO SHOULDN'T THIS HAVE A DIMENSION ARG??? t = sd.nn().softmax(in); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); - tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new OldSoftMax(ia.dup()))); + tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]); break; case 24: t = sd.math().sqrt(in); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java index e676761fa..577e19ecb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/LoneTest.java @@ -25,7 +25,7 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; @@ -60,7 +60,7 @@ public class LoneTest extends BaseNd4jTest { System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1); System.out.println("Element wise stride of output " + output.elementWiseStride()); - Nd4j.getExecutioner().exec(new OldSoftMax(input, output)); + Nd4j.getExecutioner().exec(new SoftMax(input, output)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index f9ac99423..339e3e6b4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -41,9 +41,9 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy; import org.nd4j.linalg.api.memory.enums.SpillPolicy; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BroadcastOp; +import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.ops.custom.Flatten; import org.nd4j.linalg.api.ops.executioner.GridExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.broadcast.*; @@ -72,13 +72,13 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps; import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy; import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse; import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; import org.nd4j.linalg.api.shape.Shape; @@ -94,8 +94,6 @@ import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.primitives.Pair; -import org.nd4j.linalg.profiler.OpProfiler; -import org.nd4j.linalg.profiler.ProfilerConfig; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.MathUtils; @@ -2919,7 +2917,7 @@ public class Nd4jTestsC extends BaseNd4jTest { System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo())); INDArray output = Nd4j.create(10, 1); System.out.println("Element wise stride of output " + output.elementWiseStride()); - Nd4j.getExecutioner().exec(new OldSoftMax(input, output)); + Nd4j.getExecutioner().exec(new SoftMax(input, output)); } @Test @@ -3134,7 +3132,7 @@ public class Nd4jTestsC extends BaseNd4jTest { public void testSoftmaxRow() { for (int i = 0; i < 20; i++) { INDArray arr1 = Nd4j.zeros(1, 100); - Nd4j.getExecutioner().execAndReturn(new OldSoftMax(arr1)); + Nd4j.getExecutioner().execAndReturn(new SoftMax(arr1)); System.out.println(Arrays.toString(arr1.data().asFloat())); } } @@ -3779,7 +3777,7 @@ public class Nd4jTestsC extends BaseNd4jTest { for (int i = 0; i < 3; i++) { INDArray subset = result12.tensorAlongDimension(i, 1, 2);//result12.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all()); - assertEquals("Failed for subset " + i, bc12, subset); + assertEquals("Failed for subset [" + i + "] orders [" + orderArr + "/" + orderbc + "]", bc12, subset); } } } @@ -5725,9 +5723,9 @@ public class Nd4jTestsC extends BaseNd4jTest { val reference = original.dup(original.ordering()); val expected = original.dup(original.ordering()); - Nd4j.getExecutioner().execAndReturn(new OldSoftMax(expected)); + Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(expected, expected, -1)); - val result = Nd4j.getExecutioner().exec(new OldSoftMax(original, original.dup(original.ordering()))); + val result = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(original, original.dup(original.ordering())))[0]; assertEquals(reference, original); assertEquals(expected, result); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java index b6ba052c8..fe02070fd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java @@ -23,12 +23,12 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldLogSoftMax; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -158,9 +158,9 @@ public class CrashTest extends BaseNd4jTest { // logisoftmax, softmax & softmax derivative - Nd4j.getExecutioner().exec(new OldSoftMax(x)); - Nd4j.getExecutioner().exec(new SoftMaxDerivative(x)); - Nd4j.getExecutioner().exec(new OldLogSoftMax(x)); + Nd4j.getExecutioner().exec((CustomOp) new SoftMax(x)); + Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(x)); + Nd4j.getExecutioner().exec((CustomOp) new LogSoftMax(x)); // BooleanIndexing diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java index 4088a4a73..3dc773981 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java @@ -30,12 +30,13 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy; import org.nd4j.linalg.api.memory.enums.MirroringPolicy; import org.nd4j.linalg.api.memory.enums.SpillPolicy; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.impl.reduce.bool.IsInf; import org.nd4j.linalg.api.ops.impl.reduce.bool.IsNaN; import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero; import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldEqualTo; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -334,7 +335,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { public void testTypesValidation_3() { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); - val result = Nd4j.getExecutioner().exec(new OldSoftMax(arrayX)); + val result = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(arrayX, arrayX, -1)); } public void testTypesValidation_4() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java index 58d564f83..c89ee8dfd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java @@ -25,7 +25,9 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.impl.scalar.Step; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.gradient.*; import org.nd4j.linalg.api.ops.impl.transforms.strict.*; import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative; @@ -217,8 +219,8 @@ public class DerivativeTests extends BaseNd4jTest { } } - INDArray sm = Nd4j.getExecutioner().exec(new OldSoftMax(z.dup())); - INDArray zPrime = Nd4j.getExecutioner().exec(new SoftMaxDerivative(z)); + INDArray sm = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(z.dup()))[0]; + INDArray zPrime = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(z))[0]; System.out.println(Arrays.toString(sm.data().asDouble())); System.out.println(Arrays.toString(zPrime.data().asDouble())); assertNotEquals(sm, zPrime); @@ -396,7 +398,7 @@ public class DerivativeTests extends BaseNd4jTest { //random array represeting preout INDArray X = Nd4j.rand(1, 2); //preout transformed to y_hat with softmax - INDArray YHat = Nd4j.getExecutioner().exec(new OldSoftMax(X.dup())); + INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0]; //hard coding something to construct a function with, using MSE INDArray Y = Nd4j.create(new double[][] {{0.123, 1 - 0.123}}); @@ -404,7 +406,7 @@ public class DerivativeTests extends BaseNd4jTest { //This is the MSE now double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue(); - INDArray softmaxDer = Nd4j.getExecutioner().exec(new SoftMaxDerivative(X.dup())); + INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0]; //the way we apply the chain rule now is 2*(y-yhat)*softmaxder INDArray dLdY = Y.sub(YHat).mul(-2); @@ -444,13 +446,13 @@ public class DerivativeTests extends BaseNd4jTest { double x = X.getDouble(0, i); Xiplus = X.dup(); Xiplus.put(0, i, x + epsilon); - YHatplus = Nd4j.getExecutioner().exec(new OldSoftMax(Xiplus.dup())); + YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0]; lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue(); // -epsilon Ximinus = X.dup(); Ximinus.put(0, i, x - epsilon); - YHatminus = Nd4j.getExecutioner().exec(new OldSoftMax(Ximinus.dup())); + YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0]; lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue(); double gradienti = (lossplus - lossminus) / (2 * epsilon); @@ -479,16 +481,16 @@ public class DerivativeTests extends BaseNd4jTest { INDArray X = Nd4j.rand(1, someLength); //preout transformed to y_hat with softmax - INDArray YHat = Nd4j.getExecutioner().exec(new OldSoftMax(X.dup())); + INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0]; //hard coding something to construct a function with, using MSE INDArray temp = Nd4j.rand(1, someLength); - INDArray Y = Nd4j.getExecutioner().exec(new OldSoftMax(temp)); + INDArray Y = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(temp))[0]; //This is the MSE now double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue(); - INDArray softmaxDer = Nd4j.getExecutioner().exec(new SoftMaxDerivative(X.dup())); + INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0]; //the way we apply the chain rule now is 2*(y-yhat)*softmaxder INDArray dLdY = Y.sub(YHat).mul(-2); @@ -511,13 +513,13 @@ public class DerivativeTests extends BaseNd4jTest { double x = X.getDouble(0, i); Xiplus = X.dup(); Xiplus.put(0, i, x + epsilon); - YHatplus = Nd4j.getExecutioner().exec(new OldSoftMax(Xiplus.dup())); + YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0]; lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue(); // -epsilon Ximinus = X.dup(); Ximinus.put(0, i, x - epsilon); - YHatminus = Nd4j.getExecutioner().exec(new OldSoftMax(Ximinus.dup())); + YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0]; lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue(); double gradienti = (lossplus - lossminus) / (2 * epsilon); @@ -538,14 +540,14 @@ public class DerivativeTests extends BaseNd4jTest { // this is only for X a row vector // should return rank 2 matrix diagonal elements are pi*(1-pi) //rest are -pi*pj - INDArray p = Nd4j.getExecutioner().exec(new OldSoftMax(X.dup())); + INDArray p = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0]; INDArray pCol = p.dup().transpose(); INDArray pipj = pCol.mmul(p); pipj.muli(-1); //so now pipj is correct except for the diagonal elements // which by the way is what our current softmax der gives us - INDArray diagp = Nd4j.getExecutioner().exec(new SoftMaxDerivative(X.dup())); + INDArray diagp = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0]; //ugly for loop to correct diag elements diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java index ca3d85ebb..c6414002e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java @@ -24,6 +24,7 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; @@ -42,11 +43,11 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp; import org.nd4j.linalg.api.ops.impl.transforms.strict.Log; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange; import org.nd4j.linalg.api.ops.random.impl.DropOut; import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; @@ -304,11 +305,11 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test public void testRowSoftmax() { - OpExecutioner opExecutioner = Nd4j.getExecutioner(); - INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); - OldSoftMax softMax = new OldSoftMax(arr); - opExecutioner.exec(softMax); - assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1); + val opExecutioner = Nd4j.getExecutioner(); + val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); + val softMax = new SoftMax(arr); + opExecutioner.exec((CustomOp) softMax); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); } @@ -373,7 +374,7 @@ public class OpExecutionerTests extends BaseNd4jTest { public void testSoftmax() { INDArray vec = Nd4j.linspace(1, 6, 6, DataType.DOUBLE); INDArray matrix = vec.dup().reshape('f', 2, 3); - Nd4j.getExecutioner().exec(new OldSoftMax(matrix)); + Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix)); INDArray matrixAssertion = Nd4j.create( new double[] {0.015876241, 0.015876241, 0.11731043, 0.11731043, 0.86681336, 0.86681336}, new int[] {2, 3}, 'f'); @@ -384,7 +385,7 @@ public class OpExecutionerTests extends BaseNd4jTest { public void testOtherSoftmax() { INDArray vec = Nd4j.linspace(1, 18, 18, DataType.DOUBLE); INDArray matrix = vec.dup().reshape('f', 3, 6); - Nd4j.getExecutioner().exec(new OldSoftMax(matrix)); + Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix)); INDArray assertion = Nd4j.create(new double[] {2.9067235E-7, 2.9067235E-7, 2.9067235E-7, 5.8383102E-6, 5.8383102E-6, 5.8383102E-6, 1.1726559E-4, 1.1726559E-4, 1.1726559E-4, 0.0023553425, 0.0023553425, 0.0023553425, 0.047308315, 0.047308315, 0.047308315, 0.95021296, 0.95021296, @@ -517,9 +518,9 @@ public class OpExecutionerTests extends BaseNd4jTest { 0.3049033, 0.29277474, 0.29136384, 0.30316526, 0.2807459}, new int[] {150, 3}, 'f'); System.out.println("Data:" + input.data().length()); - OldSoftMax softMax = new OldSoftMax(input); - Nd4j.getExecutioner().exec(softMax); - assertEquals(assertion, softMax.z()); + val softMax = new SoftMax(input); + Nd4j.getExecutioner().exec((CustomOp) softMax); + assertEquals(assertion, softMax.outputArguments()[0]); } @@ -557,9 +558,9 @@ public class OpExecutionerTests extends BaseNd4jTest { public void testSoftMax() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); - OldSoftMax softMax = new OldSoftMax(arr); - opExecutioner.exec(softMax); - assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1); + val softMax = new SoftMax(arr); + opExecutioner.exec((CustomOp) softMax); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index 1ce00d896..614044a26 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; @@ -50,12 +51,12 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp; import org.nd4j.linalg.api.ops.impl.transforms.strict.Log; -import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax; import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -96,9 +97,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { public void testSoftmaxReference() { INDArray input = Nd4j.linspace(1,4,4, DataType.FLOAT).reshape(2,2); INDArray dup = input.dup(); - Nd4j.getExecutioner().exec(new OldSoftMax(dup)); + Nd4j.getExecutioner().exec((CustomOp) new SoftMax(dup)); INDArray result = Nd4j.zeros(DataType.FLOAT, 2,2); - Nd4j.getExecutioner().exec(new OldSoftMax(input,result)); + Nd4j.getExecutioner().exec((CustomOp) new SoftMax(input,result)); assertEquals(dup,result); @@ -322,9 +323,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest { public void testRowSoftmax() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); - OldSoftMax softMax = new OldSoftMax(arr); - opExecutioner.exec(softMax); - assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1); + val softMax = new SoftMax(arr); + opExecutioner.exec((CustomOp) softMax); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); } @Test @@ -422,23 +423,23 @@ public class OpExecutionerTestsC extends BaseNd4jTest { public void testSoftMax() { OpExecutioner opExecutioner = Nd4j.getExecutioner(); INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1); - OldSoftMax softMax = new OldSoftMax(arr); - opExecutioner.exec(softMax); - assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1); + val softMax = new SoftMax(arr); + opExecutioner.exec((CustomOp) softMax); + assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); - OldSoftMax softmax = new OldSoftMax(linspace.dup()); - Nd4j.getExecutioner().exec(softmax); - assertEquals(linspace.rows(), softmax.z().sumNumber().doubleValue(), 1e-1); + val softmax = new SoftMax(linspace.dup()); + Nd4j.getExecutioner().exec((CustomOp) softmax); + assertEquals(linspace.rows(), softmax.outputArguments()[0].sumNumber().doubleValue(), 1e-1); } @Test public void testDimensionSoftMax() { INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); - OldSoftMax max = new OldSoftMax(linspace); - Nd4j.getExecutioner().exec(max); - linspace.assign(max.z()); + val max = new SoftMax(linspace); + Nd4j.getExecutioner().exec((CustomOp) max); + linspace.assign(max.outputArguments()[0]); assertEquals(getFailureMessage(), linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1); } @@ -782,7 +783,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest { public void testSoftmax() { INDArray vec = Nd4j.linspace(1, 18, 18, DataType.DOUBLE); INDArray matrix = vec.dup().reshape(3, 6); - Nd4j.getExecutioner().exec(new OldSoftMax(matrix)); + Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix)); INDArray assertion = Nd4j.create( new double[] {0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913, 0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913, diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index 4366e99a1..e84e136fe 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -381,30 +381,32 @@ public class OperationProfilerTests extends BaseNd4jTest { INDArray x = Nd4j.create(1000, 1000).assign(1.0); INDArray y = Nd4j.create(1000, 1000).assign(1.0); - for (int e = 0; e < 10000; e++) { + int iterations = 100; + + for (int e = 0; e < iterations; e++) { x.addi(y); } Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); val nanosC = System.nanoTime(); - for (int e = 0; e < 10000; e++) { + for (int e = 0; e < iterations; e++) { x.addi(y); } val nanosD = System.nanoTime(); - val avgB = (nanosD - nanosC) / 10000; + val avgB = (nanosD - nanosC) / iterations; Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); val nanosA = System.nanoTime(); - for (int e = 0; e < 10000; e++) { + for (int e = 0; e < iterations; e++) { x.addi(y); } val nanosB = System.nanoTime(); - val avgA = (nanosB - nanosA) / 10000; + val avgA = (nanosB - nanosA) / iterations; log.info("A: {}; B: {}", avgA, avgB); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index 66979b95d..e5a567dc1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -1429,15 +1429,17 @@ public class RandomTests extends BaseNd4jTest { @Test public void testRngRepeatabilityUniform(){ + val nexp = Nd4j.create(DataType.FLOAT, 10); Nd4j.getRandom().setSeed(12345); - INDArray out1 = Nd4j.create(DataType.FLOAT, 10); + val out1 = Nd4j.create(DataType.FLOAT, 10); Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out1, 0.0, 1.0)); Nd4j.getRandom().setSeed(12345); - INDArray out2 = Nd4j.create(DataType.FLOAT, 10); + val out2 = Nd4j.create(DataType.FLOAT, 10); Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out2, 0.0, 1.0)); assertEquals(out1, out2); + assertNotEquals(nexp, out1); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java index 23ab53907..9489cb886 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/ShapeBufferTests.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.shape; +import lombok.val; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -48,18 +49,17 @@ public class ShapeBufferTests extends BaseNd4jTest { @Test public void testRank() { - int[] shape = {2, 4}; - int[] stride = {1, 2}; - IntBuffer buff = Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt(); - int rank = 2; - assertEquals(rank, Shape.rank(buff)); - + long[] shape = {2, 4}; + long[] stride = {1, 2}; + val shapeInfoBuffer = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false); + val buff = shapeInfoBuffer.asNioLong(); + assertEquals(2, Shape.rank(buff)); } @Test public void testArrCreationShape() { - INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); + val arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2); for (int i = 0; i < 2; i++) assertEquals(2, arr.size(i)); int[] stride = ArrayUtil.calcStrides(new int[] {2, 2}); @@ -70,12 +70,13 @@ public class ShapeBufferTests extends BaseNd4jTest { @Test public void testShape() { - int[] shape = {2, 4}; - int[] stride = {1, 2}; - IntBuffer buff = Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt(); - IntBuffer shapeView = Shape.shapeOf(buff); + long[] shape = {2, 4}; + long[] stride = {1, 2}; + val shapeInfoBuffer = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false); + val buff = shapeInfoBuffer.asNioLong(); + val shapeView = Shape.shapeOf(buff); assertTrue(Shape.contentEquals(shape, shapeView)); - IntBuffer strideView = Shape.stride(buff); + val strideView = Shape.stride(buff); assertTrue(Shape.contentEquals(stride, strideView)); assertEquals('c', Shape.order(buff)); assertEquals(1, Shape.elementWiseStride(buff)); @@ -86,9 +87,9 @@ public class ShapeBufferTests extends BaseNd4jTest { @Test public void testBuff() { - int[] shape = {1, 2}; - int[] stride = {1, 2}; - IntBuffer buff = Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt(); + long[] shape = {1, 2}; + long[] stride = {1, 2}; + val buff = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false).asNioLong(); assertTrue(Shape.isVector(buff)); } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 2b3bf3875..ee83e82e5 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -2661,4 +2661,9 @@ public abstract class BaseDataBuffer implements DataBuffer { this.indexer = null; this.pointer = null; } + + @Override + public long platformAddress() { + return address(); + } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java index 10147633b..28b04f07f 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java @@ -84,6 +84,14 @@ public interface DataBuffer extends Serializable, AutoCloseable { */ long address(); + /** + * Returns the address of platform-specific pointer: + * - for native backend that'll be host pointer + * - for cuda backend that'll be device pointer + * @return + */ + long platformAddress(); + /** * Returns true if the underlying data source * is the same for both buffers (referential equals)