[WIP] CUDA tweaks (#60)

* special cpu concat

Signed-off-by: raver119 <raver119@gmail.com>

* special concat fix

Signed-off-by: raver119 <raver119@gmail.com>

* OpProfiler tweak for absent host pointers

Signed-off-by: raver119 <raver119@gmail.com>

* minor test tweak to see orders

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA broadcasting diff orders fix

Signed-off-by: raver119 <raver119@gmail.com>

* faster iterations

Signed-off-by: raver119 <raver119@gmail.com>

* OldSoftMax/OldLogSoftMax gone

Signed-off-by: raver119 <raver119@gmail.com>

* RandomLauncher tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* additional check in random tests

Signed-off-by: raver119 <raver119@gmail.com>

* skip prepare/register action for empty arrays

Signed-off-by: raver119 <raver119@gmail.com>

* npz float16 fix

Signed-off-by: raver119 <raver119@gmail.com>

* empty reduction cuda fixes

Signed-off-by: raver119 <raver119@gmail.com>

* ShapeBufferTests tweaks

Signed-off-by: raver119 <raver119@gmail.com>
branch: master
raver119 2019-07-15 16:36:35 +03:00, committed by AlexDBlack
parent 6ce458e949
commit 9cf28ea6c9
43 changed files with 391 additions and 599 deletions

View File

@ -21,26 +21,27 @@
#include <NDArray.h>
#include <helpers/helper_random.h>
#include <graph/RandomGenerator.h>
#include <execution/LaunchContext.h>
namespace nd4j {
class RandomLauncher {
public:
static void applyDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr);
static void applyInvertedDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr);
static void applyAlphaDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z = nullptr);
static void applyDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr);
static void applyInvertedDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr);
static void applyAlphaDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z = nullptr);
static void fillUniform(nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to);
static void fillUniform(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to);
static void fillGaussian(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);
static void fillGaussian(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);
static void fillExponential(nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda);
static void fillExponential(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda);
static void fillLogNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);
static void fillLogNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);
static void fillTruncatedNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);
static void fillTruncatedNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);
static void fillBinomial(nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob);
static void fillBinomial(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob);
static void fillBernoulli(nd4j::graph::RandomGenerator& rng, NDArray* array, double prob);
static void fillBernoulli(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double prob);
};
}

View File

@ -23,76 +23,97 @@
#include <helpers/RandomLauncher.h>
#include <graph/RandomGenerator.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/PointersManager.h>
namespace nd4j {
// FIXME: implement this
void RandomLauncher::applyDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) {
void RandomLauncher::applyDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) {
if (z == nullptr)
z = array;
ExtraArguments arguments({retainProb});
PointersManager pm(context, "applyDropOut");
NativeOpExecutioner::execRandom(nullptr, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
NativeOpExecutioner::execRandom(context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
pm.synchronize();
}
void RandomLauncher::applyInvertedDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) {
void RandomLauncher::applyInvertedDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) {
if (z == nullptr)
z = array;
ExtraArguments arguments({retainProb});
PointersManager pm(context, "applyInvertedDropOut");
NativeOpExecutioner::execRandom(nullptr, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
NativeOpExecutioner::execRandom(context, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
pm.synchronize();
}
void RandomLauncher::applyAlphaDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) {
void RandomLauncher::applyAlphaDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) {
if (z == nullptr)
z = array;
ExtraArguments arguments({retainProb, alpha, beta, alphaPrime});
PointersManager pm(context, "applyAlphaDropOut");
NativeOpExecutioner::execRandom(nullptr, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
NativeOpExecutioner::execRandom(context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
pm.synchronize();
}
void RandomLauncher::fillBernoulli(nd4j::graph::RandomGenerator& rng, NDArray* array, double prob) {
void RandomLauncher::fillBernoulli(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double prob) {
ExtraArguments arguments({prob});
PointersManager pm(context, "fillBernoulli");
NativeOpExecutioner::execRandom(nullptr, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
NativeOpExecutioner::execRandom(context, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
pm.synchronize();
}
void RandomLauncher::fillUniform(nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to) {
void RandomLauncher::fillUniform(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to) {
ExtraArguments arguments({from, to});
PointersManager pm(context, "fillUniform");
NativeOpExecutioner::execRandom(nullptr, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
NativeOpExecutioner::execRandom(context, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
pm.synchronize();
}
void RandomLauncher::fillGaussian(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
void RandomLauncher::fillGaussian(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
ExtraArguments arguments({mean, stdev});
PointersManager pm(context, "fillGaussian");
NativeOpExecutioner::execRandom(nullptr, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
pm.synchronize();
}
void RandomLauncher::fillExponential(nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda) {
void RandomLauncher::fillExponential(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda) {
ExtraArguments arguments({lambda});
PointersManager pm(context, "fillExponential");
NativeOpExecutioner::execRandom(nullptr, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
NativeOpExecutioner::execRandom(context, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
pm.synchronize();
}
void RandomLauncher::fillLogNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
void RandomLauncher::fillLogNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
ExtraArguments arguments({mean, stdev});
PointersManager pm(context, "fillLogNormal");
NativeOpExecutioner::execRandom(nullptr, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
pm.synchronize();
}
void RandomLauncher::fillTruncatedNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
void RandomLauncher::fillTruncatedNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
ExtraArguments arguments({mean, stdev});
PointersManager pm(context, "fillTruncatedNormal");
NativeOpExecutioner::execRandom(nullptr, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
NativeOpExecutioner::execRandom(context, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
pm.synchronize();
}
void RandomLauncher::fillBinomial(nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) {
void RandomLauncher::fillBinomial(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) {
ExtraArguments arguments({(double) trials, prob});
PointersManager pm(context, "fillBinomial");
NativeOpExecutioner::execRandom(nullptr, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
NativeOpExecutioner::execRandom(context, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
pm.synchronize();
}
}

View File

@ -128,6 +128,10 @@ namespace functions {
}
__syncthreads();
auto xOrder = shape::order(xShapeInfo);
auto yOrder = shape::order(tadOnlyShapeInfo);
auto zOrder = shape::order(tadOnlyShapeInfoZ);
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
@ -135,7 +139,7 @@ namespace functions {
auto rZ = z + tadOffsetsZ[r];
if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) {
if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1 && xOrder == yOrder && xOrder == zOrder) {
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]);
}
@ -190,6 +194,9 @@ namespace functions {
}
__syncthreads();
auto xOrder = shape::order(tadOnlyShapeInfo);
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(tadOnlyShapeInfoZ);
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
@ -197,7 +204,7 @@ namespace functions {
auto rZ = z + tadOffsetsZ[r];
if(tadEWS > 0 && zEWS > 0 && yEWS > 0) {
if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && xOrder == yOrder && xOrder == zOrder) {
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]);
}
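Context for the order checks added above: the fast branch walks x, the per-tad slice of y, and z with plain element-wise strides, and a given linear index only refers to the same logical element in all three buffers when their orderings agree. A small self-contained Java illustration of how 'c' and 'f' ordering map the same linear index to different logical positions (the 2x3 shape is purely illustrative):

public class OrderSketch {
    public static void main(String[] args) {
        int rows = 2, cols = 3;
        // linear index i -> (row, col) under row-major ('c') and column-major ('f') order
        for (int i = 0; i < rows * cols; i++) {
            int cRow = i / cols, cCol = i % cols;   // 'c': last dimension varies fastest
            int fRow = i % rows, fCol = i / rows;   // 'f': first dimension varies fastest
            System.out.printf("i=%d  c=(%d,%d)  f=(%d,%d)%n", i, cRow, cCol, fRow, fCol);
        }
        // Walking two buffers of different order with the same linear index pairs up
        // different logical elements, which is why the kernel now requires
        // xOrder == yOrder == zOrder before taking the strided fast path.
    }
}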

View File

@ -44,7 +44,7 @@ namespace nd4j {
auto z = OUTPUT_VARIABLE(0);
auto f = T_ARG(0);
RandomLauncher::fillBernoulli(rng, z, f);
RandomLauncher::fillBernoulli(block.launchContext(), rng, z, f);
return Status::OK();
}

View File

@ -53,7 +53,7 @@ namespace nd4j {
auto z = OUTPUT_VARIABLE(0);
auto lambda = T_ARG(0);
RandomLauncher::fillExponential(rng, z, lambda);
RandomLauncher::fillExponential(block.launchContext(), rng, z, lambda);
return Status::OK();
}

View File

@ -39,7 +39,7 @@ namespace nd4j {
functions::random::RandomFunction<T>::template execTransform<randomOps::GaussianDistribution<T>>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data());
*/
RandomLauncher::fillGaussian(rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));
RandomLauncher::fillGaussian(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));
return Status::OK();
}

View File

@ -53,7 +53,7 @@ namespace nd4j {
*/
REQUIRE_TRUE(block.numT() > 1, 0, "RandomUniform: to/from must be set");
RandomLauncher::fillUniform(rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));
RandomLauncher::fillUniform(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));
return Status::OK();
}
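For reference, on the Java side the uniform and gaussian distributions touched by these op changes are usually reached through the Nd4j factory helpers; a minimal usage sketch (shapes and seed are illustrative, not taken from this diff):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RandomSketch {
    public static void main(String[] args) {
        Nd4j.getRandom().setSeed(119);     // illustrative seed
        INDArray u = Nd4j.rand(3, 3);      // uniform samples in [0, 1)
        INDArray g = Nd4j.randn(3, 3);     // standard gaussian samples
        System.out.println(u);
        System.out.println(g);
    }
}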

View File

@ -1203,71 +1203,7 @@ static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& o
//////////////////////////////////////////////////////////////////////////
template<typename T>
static void concat_(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
const uint numOfArrs = inArrs.size();
int outDim;
const bool isOutputVector = output.isCommonVector(outDim);
if(isOutputVector || (axis == 0 && output.ordering() == 'c')) {
bool allVectorsOrScalars = true;
const uint outEws = isOutputVector ? output.stridesOf()[outDim] : output.ews();
std::vector<int> nonUnityDim(numOfArrs);
std::vector<Nd4jLong> zOffset(numOfArrs);
for(int i = 0; i < numOfArrs; i++) {
allVectorsOrScalars &= (inArrs[i]->lengthOf() == 1 || inArrs[i]->isCommonVector(nonUnityDim[i]));
if(!allVectorsOrScalars)
break;
if(i == 0) zOffset[0] = 0;
else zOffset[i] = zOffset[i - 1] + outEws * inArrs[i - 1]->lengthOf();
}
if(allVectorsOrScalars) {
T* outBuff = output.bufferAsT<T>();
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (uint r = 0; r < numOfArrs; r++) {
const uint arrLen = inArrs[r]->lengthOf();
const uint xEws = (arrLen == 1) ? 1 : inArrs[r]->stridesOf()[nonUnityDim[r]];
T *z = outBuff + zOffset[r];
T *x = inArrs[r]->bufferAsT<T>();
if(outEws == 1 && xEws == 1)
for (uint e = 0; e < arrLen; e++)
z[e] = x[e];
else
for (uint e = 0; e < arrLen; e++)
z[e * outEws] = x[e * xEws];
}
return;
}
}
const int rank = inArrs[0]->rankOf();
const int rank2 = 2*rank;
std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
// take into account indices for first array
indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
// loop through the rest of input arrays
for(int i = 1; i < numOfArrs; ++i) {
indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from
indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding)
}
PRAGMA_OMP_PARALLEL_FOR_SIMD
for(int i = 0; i < numOfArrs; ++i) {
auto temp = output(indices[i], true);
nd4j::TransformLoops<T,T,T>::template loopTransform<simdOps::Assign<T,T>, false>(inArrs[i]->bufferAsT<T>(), inArrs[i]->getShapeInfo(), temp.bufferAsT<T>(), temp.getShapeInfo(), nullptr);
// temp.assign(inArrs[i]);
}
nd4j::SpecialMethods<T>::concatCpuGeneric(inArrs, output, axis);
}
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {

View File

@ -81,7 +81,7 @@ namespace nd4j {
auto z = OUTPUT_VARIABLE(0); //NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
RandomLauncher::fillUniform(block.randomGenerator(), z, from, to);
RandomLauncher::fillUniform(block.launchContext(), block.randomGenerator(), z, from, to);
// FIXME:
//OVERWRITE_RESULT(z);
@ -105,7 +105,7 @@ namespace nd4j {
if (!block.isInplace())
z->assign(input);
RandomLauncher::applyDropOut(block.randomGenerator(), z, prob);
RandomLauncher::applyDropOut(block.launchContext(), block.randomGenerator(), z, prob);
}
break;
case nd4j::random::DropOutInverted: {
@ -140,7 +140,7 @@ namespace nd4j {
auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
RandomLauncher::fillGaussian(block.randomGenerator(), z, mean, stdev);
RandomLauncher::fillGaussian(block.launchContext(), block.randomGenerator(), z, mean, stdev);
// FIXME: !!
//OVERWRITE_RESULT(z);
@ -168,7 +168,7 @@ namespace nd4j {
auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
RandomLauncher::fillBernoulli(block.randomGenerator(), z, prob);
RandomLauncher::fillBernoulli(block.launchContext(), block.randomGenerator(), z, prob);
// FIXME:
//OVERWRITE_RESULT(z);
@ -201,7 +201,7 @@ namespace nd4j {
auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
RandomLauncher::fillBinomial(block.randomGenerator(), z, trials, prob);
RandomLauncher::fillBinomial(block.launchContext(), block.randomGenerator(), z, trials, prob);
// FIXME: !!!
//OVERWRITE_RESULT(z);
@ -233,7 +233,7 @@ namespace nd4j {
auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
RandomLauncher::fillLogNormal(block.randomGenerator(), z, mean, stdev);
RandomLauncher::fillLogNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev);
// FIXME: !!
//OVERWRITE_RESULT(z);
@ -265,7 +265,7 @@ namespace nd4j {
auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_<T>('c', shape, block.getWorkspace());
RandomLauncher::fillTruncatedNormal(block.randomGenerator(), z, mean, stdev);
RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev);
// FIXME: !!!
//OVERWRITE_RESULT(z);
@ -301,7 +301,7 @@ namespace nd4j {
if (!block.isInplace())
z->assign(input);
RandomLauncher::applyAlphaDropOut(block.randomGenerator(), z, prob, a, b, pa);
RandomLauncher::applyAlphaDropOut(block.launchContext(), block.randomGenerator(), z, prob, a, b, pa);
}
break;
case nd4j::random::Linspace: {

View File

@ -28,9 +28,81 @@
#include <NDArray.h>
#include <ops/declarable/CustomOperations.h>
#include <types/types.h>
#include <helpers/Loops.h>
namespace nd4j {
/**
* Concatenates multiple arrays of the same shape together
* along a particular dimension
*/
template <typename T>
void SpecialMethods<T>::concatCpuGeneric(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
const uint numOfArrs = inArrs.size();
int outDim;
const bool isOutputVector = output.isCommonVector(outDim);
if(isOutputVector || (axis == 0 && output.ordering() == 'c')) {
bool allVectorsOrScalars = true;
const uint outEws = isOutputVector ? output.stridesOf()[outDim] : output.ews();
std::vector<int> nonUnityDim(numOfArrs);
std::vector<Nd4jLong> zOffset(numOfArrs);
for(int i = 0; i < numOfArrs; i++) {
allVectorsOrScalars &= (inArrs[i]->lengthOf() == 1 || inArrs[i]->isCommonVector(nonUnityDim[i]));
if(!allVectorsOrScalars)
break;
if(i == 0) zOffset[0] = 0;
else zOffset[i] = zOffset[i - 1] + outEws * inArrs[i - 1]->lengthOf();
}
if(allVectorsOrScalars) {
T* outBuff = output.bufferAsT<T>();
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (uint r = 0; r < numOfArrs; r++) {
const uint arrLen = inArrs[r]->lengthOf();
const uint xEws = (arrLen == 1) ? 1 : inArrs[r]->stridesOf()[nonUnityDim[r]];
T *z = outBuff + zOffset[r];
T *x = inArrs[r]->bufferAsT<T>();
if(outEws == 1 && xEws == 1)
for (uint e = 0; e < arrLen; e++)
z[e] = x[e];
else
for (uint e = 0; e < arrLen; e++)
z[e * outEws] = x[e * xEws];
}
return;
}
}
const int rank = inArrs[0]->rankOf();
const int rank2 = 2*rank;
std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
// take into account indices for first array
indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
// loop through the rest of input arrays
for(int i = 1; i < numOfArrs; ++i) {
indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from
indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding)
}
PRAGMA_OMP_PARALLEL_FOR_SIMD
for(int i = 0; i < numOfArrs; ++i) {
auto temp = output(indices[i], true);
nd4j::TransformLoops<T,T,T>::template loopTransform<simdOps::Assign<T,T>, false>(inArrs[i]->bufferAsT<T>(), inArrs[i]->getShapeInfo(), temp.bufferAsT<T>(), temp.getShapeInfo(), nullptr);
}
}
/**
* Concatenates multiple arrays of the same shape together
* along a particular dimension
@ -38,24 +110,14 @@ namespace nd4j {
template <typename T>
void SpecialMethods<T>::concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *vresult, Nd4jLong *resultShapeInfo) {
auto result = reinterpret_cast<T *>(vresult);
std::vector<Nd4jLong> iArgs = {dimension};
std::vector<double> tArgs;
std::vector<bool> bArgsEmpty;
std::vector<NDArray*> inputs(numArrays);
std::vector<NDArray*> outputs(1);
outputs[0] = new NDArray(static_cast<void*>(result), static_cast<Nd4jLong*>(resultShapeInfo));
NDArray output(static_cast<void*>(result), static_cast<Nd4jLong*>(resultShapeInfo));
for(int i = 0; i < numArrays; ++i)
inputs[i] = new NDArray(static_cast<void *>(data[i]), static_cast<Nd4jLong*>(inputShapeInfo[i]));
nd4j::ops::concat op;
auto status = op.execute(inputs, outputs, tArgs, iArgs, bArgsEmpty);
if(status != Status::OK())
throw std::runtime_error("concatCpuGeneric fails to be executed !");
delete outputs[0];
nd4j::SpecialMethods<T>::concatCpuGeneric(inputs, output, dimension);
for(int i = 0; i < numArrays; ++i)
delete inputs[i];

View File

@ -30,6 +30,8 @@
#include <pointercast.h>
namespace nd4j {
class NDArray;
//FIXME: get rid of this redefinition
typedef union
{
@ -47,6 +49,7 @@ namespace nd4j {
template <typename T>
class ND4J_EXPORT SpecialMethods {
public:
static void concatCpuGeneric(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis);
static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *result, Nd4jLong *resultShapeInfo);
static void accumulateGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length);
static void averageGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length, bool propagate);

View File

@ -541,8 +541,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.strict.Log.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.OldLogSoftMax.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh.class,

View File

@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
@ -161,109 +162,4 @@ public enum Activation {
throw new UnsupportedOperationException("Activation function not yet supported: " + this);
}
}
/**
* Get the Activation function as an ND4J Transform, applied on either the input or a copy of the input
*
* @param in Input to apply the activation function op to
* @param dup If true: duplicate the array before applying the transform. If false: don't duplicate
* @return The transform op (execute using {@code Nd4j.getExecutioner().exec(op)}
*/
public Op asTransform(INDArray in, boolean dup) {
if (dup) {
in = in.dup();
}
switch (this) {
case CUBE:
return new Cube(in);
case ELU:
return new ELU(in);
case HARDSIGMOID:
return new HardSigmoid(in);
case HARDTANH:
return new HardTanh(in);
case IDENTITY:
return new OldIdentity(in);
case LEAKYRELU:
return new LeakyReLU(in);
case RATIONALTANH:
return new RationalTanh(in);
case RELU:
return new RectifiedLinear(in);
case SIGMOID:
return new Sigmoid(in);
case SOFTMAX:
return new OldSoftMax(in);
case SOFTPLUS:
return new SoftPlus(in);
case SOFTSIGN:
return new SoftSign(in);
case TANH:
return new Tanh(in);
case RECTIFIEDTANH:
return new RectifiedTanh(in);
case SELU:
return new SELU(in);
case SWISH:
return new Swish(in);
case GELU:
return new GELU(in);
case RRELU:
default:
throw new UnsupportedOperationException("Not supported via this method: " + this);
}
}
/**
* Get the Activation function <i>derivative</i> (i.e., dOut/dIn) as an ND4J Transform, applied on either the input
* or a copy of the input
*
* @param in Input to apply the activation function derivative op to
* @param dup If true: duplicate the array before applying the transform. If false: don't duplicate
* @return The op (execute using {@code Nd4j.getExecutioner().exec(op)}
*/
public Op asTransformDerivative(INDArray in, boolean dup) {
if (dup) {
in = in.dup();
}
switch (this) {
case CUBE:
return new CubeDerivative(in);
case ELU:
return new ELUDerivative(in);
case HARDSIGMOID:
return new HardSigmoidDerivative(in);
case HARDTANH:
return new HardTanhDerivative(in);
case LEAKYRELU:
return new LeakyReLUDerivative(in);
case RATIONALTANH:
return new RationalTanhDerivative(in);
case SIGMOID:
return new SigmoidDerivative(in);
case SOFTPLUS:
return new Sigmoid(in);
case SOFTSIGN:
return new SoftSignDerivative(in);
case TANH:
return new TanhDerivative(in);
case RECTIFIEDTANH:
return new RectifiedTanhDerivative(in);
case SELU:
return new SELUDerivative(in);
case SWISH:
return new SwishDerivative(in);
case SOFTMAX:
return new SoftMaxDerivative(in);
case IDENTITY:
return new ScalarSet(in, 1.0);
case RELU:
return new Step(in);
case GELU:
return new GELUDerivative(in);
case RRELU:
default:
throw new UnsupportedOperationException("Not supported via this method: " + this);
}
}
}

View File

@ -20,7 +20,8 @@ import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
@ -34,14 +35,14 @@ public class ActivationSoftmax extends BaseActivationFunction {
@Override
public INDArray getActivation(INDArray in, boolean training) {
Nd4j.getExecutioner().execAndReturn(new OldSoftMax(in));
Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(in, in));
return in;
}
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
INDArray out = Nd4j.getExecutioner().exec(new OldSoftMax(in));
INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0];
INDArray x = out.mul(epsilon).sum(1);
INDArray dLdz = out.mul(epsilon.subColumnVector(x));
return new Pair<>(dLdz, null);

View File

@ -21,14 +21,9 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

View File

@ -19,10 +19,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.nio.Buffer;
import java.util.Collections;
import java.util.List;
@ -53,10 +57,6 @@ public class SoftMax extends BaseDynamicTransformOp {
super(sameDiff, args, inPlace);
}
public SoftMax(INDArray input, INDArray result){
super(new INDArray[]{input}, new INDArray[]{result});
}
public SoftMax(SameDiff sameDiff, SDVariable[] args, int dimension) {
super(sameDiff, args, false);
this.dimension = dimension;
@ -75,13 +75,19 @@ public class SoftMax extends BaseDynamicTransformOp {
addIArgument(dimension);
}
public SoftMax(INDArray input){
this(input, input);
}
public SoftMax(INDArray input, INDArray result){
this(input, result, -1);
}
@Override
public String opName() {
return "softmax";
}
@Override
public String onnxName() {
return "Softmax";
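With OldSoftMax removed, softmax now always goes through this custom op. A minimal usage sketch consistent with the call sites changed elsewhere in this diff (the (CustomOp) cast and the ulike() result allocation mirror those call sites; the input shape is illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.factory.Nd4j;

public class SoftMaxUsage {
    public static void main(String[] args) {
        INDArray in = Nd4j.rand(3, 5);
        // out-of-place: softmax along the last dimension into a fresh array of the same shape
        INDArray out = Nd4j.exec((CustomOp) new SoftMax(in, in.ulike(), -1))[0];
        // in-place: the new single-argument constructor delegates to SoftMax(in, in)
        Nd4j.exec((CustomOp) new SoftMax(in));
        System.out.println(out);
    }
}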

View File

@ -1,87 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import java.util.Collections;
import java.util.List;
/**
* Old LogSoftMax function
*
* @author Adam Gibson
*/
public class OldLogSoftMax extends BaseTransformStrictOp {
public OldLogSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldLogSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldLogSoftMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldLogSoftMax() {
}
public OldLogSoftMax(INDArray x){
this(x,x);
}
public OldLogSoftMax(INDArray x, INDArray z) {
super(x, z);
Preconditions.checkArgument(x != null && x.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got x (source) array with shape: %ndShape", x);
Preconditions.checkArgument(z != null && z.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got z (result) array with shape: %ndShape", z);
}
@Override
public int opNum() {
return 2;
}
@Override
public String opName() {
return "old_logsoftmax";
}
@Override
public String onnxName() {
return "old_LogSoftmax";
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0));
return Collections.singletonList(ret);
}
}

View File

@ -1,94 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Collections;
import java.util.List;
/**
* Soft max function
* row_maxes is a row vector (max for each row)
* row_maxes = rowmaxes(input)
* diff = exp(input - max) / diff.rowSums()
* Outputs a probability distribution.
* Note that this is a parameterized model and requires
* the sum and max for the vector being calculated
*
* @author Adam Gibson
*/
public class OldSoftMax extends BaseTransformStrictOp {
public OldSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
}
public OldSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
}
public OldSoftMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public OldSoftMax() {
}
public OldSoftMax(INDArray x){
this(x,x);
}
public OldSoftMax(INDArray x, INDArray z) {
super(x, z);
Preconditions.checkArgument(x != null && x.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got x (source) array with shape: %ndShape", x);
Preconditions.checkArgument(z != null && z.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got z (result) array with shape: %ndShape", z);
}
@Override
public int opNum() {
return 0;
}
@Override
public String opName() {
return "old_softmax";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().softmaxDerivative(arg(), i_v.get(0), 1);
return Collections.singletonList(ret);
}
}

View File

@ -19,20 +19,27 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import java.nio.Buffer;
/**
* Softmax derivative
*
* @author Adam Gibson
*/
public class SoftMaxDerivative extends OldSoftMax {
public class SoftMaxDerivative extends SoftMax {
public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2);
super(sameDiff, new SDVariable[]{i_v1, i_v2});
}
public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, i_v1, i_v2, inPlace);
super(sameDiff, new SDVariable[]{ i_v1, i_v2}, inPlace);
}
public SoftMaxDerivative(INDArray x, INDArray z) {
@ -40,11 +47,13 @@ public class SoftMaxDerivative extends OldSoftMax {
}
public SoftMaxDerivative(INDArray x) {
super(x);
super(x, x);
}
public SoftMaxDerivative() {}
@Override
public int opNum() {
return 1;

View File

@ -1334,6 +1334,17 @@ public class Shape {
}
}
public static boolean isVector(LongBuffer shapeInfo) {
int rank = Shape.rank(shapeInfo);
if (rank > 2 || rank < 1)
return false;
else {
long len = Shape.length(shapeInfo);
val shape = Shape.shapeOf(shapeInfo);
return shape.get(0) == len || shape.get(1) == len;
}
}
/**
* Returns whether the given shape is a vector
*
@ -2498,6 +2509,14 @@ public class Shape {
return ret;
}
public static int length(LongBuffer buffer) {
int ret = 1;
val shape = Shape.shapeOf(buffer);
int rank = Shape.rank(buffer);
for (int i = 0; i < rank; i++)
ret *= shape.get(i);
return ret;
}
/**
* Gets the rank given the shape info buffer
@ -2764,8 +2783,8 @@ public class Shape {
* @param rank the rank to get the length for
* @return rank * 2 + 4
*/
public static int shapeInfoLength(int rank) {
return rank * 2 + 4;
public static int shapeInfoLength(long rank) {
return (int) rank * 2 + 4;
}
public static int shapeInfoLength(long[] shape) {
@ -3072,6 +3091,11 @@ public class Shape {
return buffer.get(length2 - 2);
}
public static long elementWiseStride(LongBuffer buffer) {
int length2 = shapeInfoLength(buffer.get(0));
return buffer.get(length2 - 2);
}
/**
* Get the element wise stride for the
* shape info buffer
@ -3179,40 +3203,6 @@ public class Shape {
throw new RuntimeException("setOrder called");
}
/**
* Creates the shape information buffer
* given the shape,stride
* @param shape the shape for the buffer
* @param stride the stride for the buffer
* @param offset the offset for the buffer
* @param elementWiseStride the element wise stride for the buffer
* @param order the order for the buffer
* @return the shape information buffer given the parameters
*/
public static DataBuffer createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
if (shape.length != stride.length)
throw new IllegalStateException("Shape and stride must be the same length");
int rank = shape.length;
int shapeBuffer[] = new int[rank * 2 + 4];
shapeBuffer[0] = rank;
int count = 1;
for (int e = 0; e < shape.length; e++)
shapeBuffer[count++] = shape[e];
for (int e = 0; e < stride.length; e++)
shapeBuffer[count++] = stride[e];
shapeBuffer[count++] = (int) offset;
shapeBuffer[count++] = elementWiseStride;
shapeBuffer[count] = (int) order;
DataBuffer ret = Nd4j.createBufferDetached(shapeBuffer);
ret.setConstant(true);
return ret;
}
public static DataBuffer createShapeInformation(long[] shape, long[] stride, long elementWiseStride, char order, DataType dataType, boolean empty) {
boolean isEmpty = empty;
if (!empty)
@ -3438,9 +3428,20 @@ public class Shape {
public static boolean contentEquals(long[] arr, IntBuffer other) {
for (int i = 0; i < arr.length; i++) {
Buffer buffer2 = (Buffer) other;
buffer2.position(i);
if (arr[i] != other.get()) {
val t = arr[i];
val o = other.get(i);
if (t != o) {
return false;
}
}
return true;
}
public static boolean contentEquals(long[] arr, LongBuffer other) {
for (int i = 0; i < arr.length; i++) {
val t = arr[i];
val o = other.get(i);
if (t != o) {
return false;
}
}
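The new LongBuffer overloads read the rank from slot 0 and the shape entries that follow it. A standalone sketch of the same length/vector logic on a raw java.nio.LongBuffer, outside the Shape class (the shape-info layout spelled out in the comment is our assumption for illustration):

import java.nio.LongBuffer;

public class ShapeInfoSketch {
    // assumed layout: [rank, shape_0..shape_{rank-1}, stride_0..stride_{rank-1}, ..., ews, order]
    static long length(LongBuffer shapeInfo) {
        int rank = (int) shapeInfo.get(0);
        long len = 1;
        for (int i = 1; i <= rank; i++)
            len *= shapeInfo.get(i);
        return len;
    }

    static boolean isVector(LongBuffer shapeInfo) {
        int rank = (int) shapeInfo.get(0);
        if (rank < 1 || rank > 2)
            return false;
        long len = length(shapeInfo);
        // a vector keeps its whole length along a single dimension
        return shapeInfo.get(1) == len || (rank == 2 && shapeInfo.get(2) == len);
    }

    public static void main(String[] args) {
        // rank-2 shape info for a 1x10 row vector: rank, shape, stride, extras, ews, order ('c' == 99)
        LongBuffer rowVector = LongBuffer.wrap(new long[]{2, 1, 10, 10, 1, 0, 1, 99});
        System.out.println(isVector(rowVector)); // true
        System.out.println(length(rowVector));   // 10
    }
}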

View File

@ -25,8 +25,8 @@ import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
@ -121,7 +121,7 @@ public class LossBinaryXENT implements ILossFunction {
INDArray scoreArr;
if (activationFn instanceof ActivationSoftmax) {
//TODO Post GPU support for custom ops: Use LogSoftMax op to avoid numerical issues when calculating score
INDArray logsoftmax = Nd4j.getExecutioner().exec(new OldSoftMax(preOutput.dup()));
INDArray logsoftmax = Nd4j.exec((CustomOp) new SoftMax(preOutput, preOutput.ulike(), -1))[0];
Transforms.log(logsoftmax, false);
scoreArr = logsoftmax.muli(labels);

View File

@ -20,7 +20,8 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
@ -139,7 +140,7 @@ public class LossMixtureDensity implements ILossFunction {
// Alpha is a softmax because
// the alpha should all sum to 1 for a given gaussian mixture.
mdc.alpha = Nd4j.getExecutioner().exec(new OldSoftMax(mdc.alpha));
mdc.alpha = Nd4j.exec((CustomOp) new SoftMax(mdc.alpha, mdc.alpha, -1))[0];
// Mu comes directly from the network as an unmolested value.
// Note that this effectively means that the output layer of

View File

@ -21,6 +21,7 @@ import lombok.val;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
@ -29,6 +30,7 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNot;
import org.nd4j.linalg.api.ops.impl.shape.Cross;
import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
@ -512,7 +514,7 @@ public class Transforms {
* @return
*/
public static INDArray softmax(INDArray in, boolean copy) {
return Nd4j.getExecutioner().exec(new OldSoftMax(in, (copy ? in.ulike() : in)));
return Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, (copy ? in.ulike() : in), -1))[0];
}
/**

View File

@ -240,7 +240,7 @@ public class OpProfiler {
String opClass = getOpClass(op);
classCounter.incrementCount(opClass);
if(op.x() == null || (op.x() != null && op.x().data().address() == lastZ && op.z() == op.x() && op.y() == null)) {
if(op.x() == null || (op.x() != null && op.x().data().platformAddress() == lastZ && op.z() == op.x() && op.y() == null)) {
// we have possible shift here
matchingCounter.incrementCount(prevOpMatching + " -> " + opClass);
matchingCounterDetailed.incrementCount(prevOpMatchingDetailed + " -> " + opClass + " " + op.opName());
@ -254,7 +254,7 @@ public class OpProfiler {
}
}
lastZ = op.z() != null ? op.z().data().address() : 0L;
lastZ = op.z() != null ? op.z().data().platformAddress() : 0L;
prevOpMatching = opClass;
prevOpMatchingDetailed = opClass + " " + op.opName();
prevOpMatchingInverted = opClass + " " + op.opName();
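The profiler now keys its matching logic on platformAddress() so it still works when an array has no host pointer, as can happen on the CUDA backend; the CUDA implementation of that method appears further down in this diff. A tiny illustrative sketch of the distinction, assuming the default backend is available:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class PointerSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(4);
        // address() reports the host pointer; platformAddress() reports the pointer the
        // backend actually computes on (device pointer on CUDA, host pointer on CPU)
        System.out.println(x.data().address());
        System.out.println(x.data().platformAddress());
    }
}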

View File

@ -610,6 +610,7 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
bb2.put((byte)((s >> 8) & 0xff));
bb2.put((byte)(s & 0xff));
}
Nd4j.getAffinityManager().tagLocation(arr, AffinityManager.Location.HOST);
map.put(fName, arr.reshape(order, shape));
} else if(dt == DataType.LONG){
long[] d = new long[(int)size];

View File

@ -72,7 +72,7 @@ public class SynchronousFlowController implements FlowController {
public void synchronizeToHost(AllocationPoint point) {
if (!point.isActualOnHostSide()) {
CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
val context = (CudaContext) allocator.getDeviceContext().getContext();
if (!point.isConstant())
waitTillFinished(point);
@ -102,7 +102,7 @@ public class SynchronousFlowController implements FlowController {
if (!point.isActualOnDeviceSide()) {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
val context = (CudaContext) allocator.getDeviceContext().getContext();
long perfD = PerformanceTracker.getInstance().helperStartTransaction();
@ -135,17 +135,17 @@ public class SynchronousFlowController implements FlowController {
@Override
public CudaContext prepareActionAllWrite(INDArray... operands) {
CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
int cId = allocator.getDeviceId();
val context = (CudaContext) allocator.getDeviceContext().getContext();
val cId = allocator.getDeviceId();
for (INDArray operand : operands) {
if (operand == null)
if (operand == null || operand.isEmpty())
continue;
Nd4j.getCompressor().autoDecompress(operand);
AllocationPoint pointData = allocator.getAllocationPoint(operand);
AllocationPoint pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
val pointData = allocator.getAllocationPoint(operand);
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
pointData.acquireLock();
@ -168,15 +168,15 @@ public class SynchronousFlowController implements FlowController {
@Override
public CudaContext prepareAction(INDArray result, INDArray... operands) {
CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
int cId = allocator.getDeviceId();
val context = (CudaContext) allocator.getDeviceContext().getContext();
val cId = allocator.getDeviceId();
if (result != null) {
if (result != null && !result.isEmpty()) {
Nd4j.getCompressor().autoDecompress(result);
prepareDelayedMemory(result);
AllocationPoint pointData = allocator.getAllocationPoint(result);
AllocationPoint pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());
val pointData = allocator.getAllocationPoint(result);
val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());
pointData.acquireLock();
@ -196,13 +196,13 @@ public class SynchronousFlowController implements FlowController {
}
for (INDArray operand : operands) {
if (operand == null)
if (operand == null || operand.isEmpty())
continue;
Nd4j.getCompressor().autoDecompress(operand);
AllocationPoint pointData = allocator.getAllocationPoint(operand);
AllocationPoint pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
val pointData = allocator.getAllocationPoint(operand);
val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
pointData.acquireLock();
@ -256,7 +256,7 @@ public class SynchronousFlowController implements FlowController {
if (operand == null)
continue;
AllocationPoint pointOperand = allocator.getAllocationPoint(operand);
val pointOperand = allocator.getAllocationPoint(operand);
pointOperand.tickDeviceWrite();
eventsProvider.storeEvent(pointOperand.getLastWriteEvent());
pointOperand.setLastWriteEvent(eventsProvider.getEvent());
@ -266,9 +266,10 @@ public class SynchronousFlowController implements FlowController {
}
public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
if (result == null)
if (result == null || result.isEmpty())
return;
AllocationPoint point = allocator.getAllocationPoint(result);
val point = allocator.getAllocationPoint(result);
point.tickDeviceWrite();
eventsProvider.storeEvent(point.getLastWriteEvent());
point.setLastWriteEvent(eventsProvider.getEvent());
@ -276,10 +277,10 @@ public class SynchronousFlowController implements FlowController {
point.releaseLock();
for (INDArray operand : operands) {
if (operand == null)
if (operand == null || operand.isEmpty())
continue;
AllocationPoint pointOperand = allocator.getAllocationPoint(operand);
val pointOperand = allocator.getAllocationPoint(operand);
pointOperand.releaseLock();
eventsProvider.storeEvent(pointOperand.getLastReadEvent());
pointOperand.setLastReadEvent(eventsProvider.getEvent());
@ -289,7 +290,7 @@ public class SynchronousFlowController implements FlowController {
@Override
public CudaContext prepareAction(AllocationPoint result, AllocationPoint... operands) {
CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
val context = (CudaContext) allocator.getDeviceContext().getContext();
if (result != null) {
result.acquireLock();
@ -299,6 +300,7 @@ public class SynchronousFlowController implements FlowController {
for (AllocationPoint operand : operands) {
if (operand == null)
continue;
operand.acquireLock();
operand.setCurrentContext(context);
}
@ -313,15 +315,16 @@ public class SynchronousFlowController implements FlowController {
protected void prepareDelayedMemory(INDArray array) {
if (configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
AllocationPoint pointData = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
val pointData = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
val pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
if (pointData.getAllocationStatus() != AllocationStatus.DEVICE)
prepareDelayedMemory(array.data());
if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
DataBuffer oShape = array.shapeInfoDataBuffer();
DataBuffer nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);
val oShape = array.shapeInfoDataBuffer();
val nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);
if (nShape == oShape)
Nd4j.getConstantHandler().moveToConstantSpace(nShape);
((JCublasNDArray) array).setShapeInfoDataBuffer(nShape);
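The guards added here skip pointer preparation and write-event registration for empty arrays, which own no data buffers to synchronize. A minimal sketch of such an array from the public API (assuming the Nd4j.empty factory method present in this version):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class EmptySketch {
    public static void main(String[] args) {
        INDArray empty = Nd4j.empty(DataType.FLOAT);
        System.out.println(empty.isEmpty());  // true
        System.out.println(empty.length());   // 0 - nothing to move between host and device
    }
}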

View File

@ -567,6 +567,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
return allocationPoint.getPointers().getHostPointer().address();
}
@Override
public long platformAddress() {
return allocationPoint.getPointers().getDevicePointer().address();
}
@Override
public Pointer pointer() {
// FIXME: very bad thing,

View File

@ -26,6 +26,7 @@ import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.base.Preconditions;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
@ -515,7 +516,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
}
// in case of regular accumulation we don't care about array state before op
ret = Nd4j.createUninitialized(dtype, retShape);
ret = Nd4j.create(dtype, retShape);
}
op.setZ(ret);
} else {
@ -536,11 +537,16 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override
public INDArray exec(IndexAccumulation op) {
val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector());
if (op.z() == null) {
long[] retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims());
INDArray ret = Nd4j.createUninitialized(DataType.LONG, retShape);
op.setZ(ret);
if (op.x().isEmpty()) {
for (val d:dimension) {
Preconditions.checkArgument(op.x().shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape");
}
}
if (op.z() == null) {
val retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims());
op.setZ(Nd4j.createUninitialized(DataType.LONG, retShape));
}
long st = profilingConfigurableHookIn(op);
@ -556,10 +562,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
return op.x();
}
if (op.z().isEmpty())
return op.z();
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
lastOp.set(op.opName());
CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
val hostXShapeInfo =
op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());

View File

@ -619,9 +619,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
"Illegal concatenation at array " + i + " and shape element " + j);
}
}
//log.info("Shape[{}]: {}", i, Arrays.toString(toConcat[i].shapeInfoDataBuffer().asInt()));
}
if (allScalars) {
@ -630,8 +627,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
outputShape[dimension] = sumAlongDim;
}
//PointerPointer dummy = new PointerPointer(new Pointer[] {null});
INDArray ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order());
nativeOps.concat(null, dimension, toConcat.length,
@ -639,11 +634,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
null, null,
ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
null, null,
//new PointerPointer(new Pointer[] {null}), new PointerPointer(new Pointer[] {null}));
null, null);
return ret;
// return super.concat(dimension,toConcat);
}
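The trimmed factory method above still funnels into the native concat path, which on the CPU side ultimately lands in the concatCpuGeneric helper shown earlier in this diff. A minimal Java usage sketch (shapes illustrative):

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ConcatSketch {
    public static void main(String[] args) {
        INDArray a = Nd4j.ones(2, 3);
        INDArray b = Nd4j.zeros(2, 3);
        // concatenate along dimension 0 -> shape [4, 3]
        INDArray c = Nd4j.concat(0, a, b);
        System.out.println(Arrays.toString(c.shape()));
    }
}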

View File

@ -31,6 +31,7 @@ import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
@ -1083,7 +1084,7 @@ public class ReductionOpValidation extends BaseOpValidation {
final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
.divi(Math.sqrt(keys.size(1)));
Nd4j.exec(new SoftMax(exec, exec, 1));
Nd4j.exec((CustomOp) new SoftMax(exec, exec, 1));
final INDArray finalOut = Nd4j.matmul(values, exec).norm1();
SameDiff sd = SameDiff.create();
@ -1111,7 +1112,7 @@ public class ReductionOpValidation extends BaseOpValidation {
final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
.divi(Math.sqrt(keys.size(1)));
exec.addi(mask.reshape(10, 3, 1).sub(1).muli(1e9));
Nd4j.exec(new SoftMax(exec, exec, 1));
Nd4j.exec((CustomOp) new SoftMax(exec, exec, 1));
final INDArray finalOut = Nd4j.matmul(values, exec).norm1();
SameDiff sd = SameDiff.create();
@ -1141,7 +1142,7 @@ public class ReductionOpValidation extends BaseOpValidation {
final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
.divi(Math.sqrt(keys.size(-2)));
exec.addi(Nd4j.tile(mask.reshape(2, 1, 3, 1), 1, 5, 1, 2).sub(1).muli(1e9));
Nd4j.exec(new SoftMax(exec, exec, -2));
Nd4j.exec((CustomOp) new SoftMax(exec, exec, -2));
final INDArray finalOut = Nd4j.matmul(values, exec).norm1();
SameDiff sd = SameDiff.create();
@ -1169,7 +1170,7 @@ public class ReductionOpValidation extends BaseOpValidation {
final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
.divi(Math.sqrt(keys.size(-2)));
Nd4j.exec(new SoftMax(exec, exec, -2));
Nd4j.exec((CustomOp) new SoftMax(exec, exec, -2));
final INDArray finalOut = Nd4j.matmul(values, exec).norm1();
SameDiff sd = SameDiff.create();
@ -1249,7 +1250,7 @@ public class ReductionOpValidation extends BaseOpValidation {
final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
.divi(Math.sqrt(keys.size(1)));
exec.addi(mask.reshape(10, 3, 1).sub(1).muli(1e9));
Nd4j.exec(new SoftMax(exec, exec, 1));
Nd4j.exec((CustomOp) new SoftMax(exec, exec, 1));
final INDArray finalOut = Nd4j.matmul(values, exec).norm1();
for (char queryOrder : new char[]{'f', 'c'}) {


@ -42,6 +42,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
@ -671,7 +672,7 @@ public class TransformOpValidation extends BaseOpValidation {
//TODO SHOULDN'T THIS HAVE A DIMENSION ARG???
t = sd.nn().softmax(in);
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new OldSoftMax(ia.dup())));
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]);
break;
case 24:
t = sd.math().sqrt(in);


@ -25,7 +25,7 @@ import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
@ -60,7 +60,7 @@ public class LoneTest extends BaseNd4jTest {
System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1);
System.out.println("Element wise stride of output " + output.elementWiseStride());
Nd4j.getExecutioner().exec(new OldSoftMax(input, output));
Nd4j.getExecutioner().exec(new SoftMax(input, output));
}
@Override


@ -41,9 +41,9 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.custom.Flatten;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.broadcast.*;
@ -72,13 +72,13 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.shape.Shape;
@ -94,8 +94,6 @@ import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.MathUtils;
@ -2919,7 +2917,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
INDArray output = Nd4j.create(10, 1);
System.out.println("Element wise stride of output " + output.elementWiseStride());
Nd4j.getExecutioner().exec(new OldSoftMax(input, output));
Nd4j.getExecutioner().exec(new SoftMax(input, output));
}
@Test
@ -3134,7 +3132,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
public void testSoftmaxRow() {
for (int i = 0; i < 20; i++) {
INDArray arr1 = Nd4j.zeros(1, 100);
Nd4j.getExecutioner().execAndReturn(new OldSoftMax(arr1));
Nd4j.getExecutioner().execAndReturn(new SoftMax(arr1));
System.out.println(Arrays.toString(arr1.data().asFloat()));
}
}
@ -3779,7 +3777,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
for (int i = 0; i < 3; i++) {
INDArray subset = result12.tensorAlongDimension(i, 1, 2);//result12.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all());
assertEquals("Failed for subset " + i, bc12, subset);
assertEquals("Failed for subset [" + i + "] orders [" + orderArr + "/" + orderbc + "]", bc12, subset);
}
}
}
@ -5725,9 +5723,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
val reference = original.dup(original.ordering());
val expected = original.dup(original.ordering());
Nd4j.getExecutioner().execAndReturn(new OldSoftMax(expected));
Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(expected, expected, -1));
val result = Nd4j.getExecutioner().exec(new OldSoftMax(original, original.dup(original.ordering())));
val result = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(original, original.dup(original.ordering())))[0];
assertEquals(reference, original);
assertEquals(expected, result);


@ -23,12 +23,12 @@ import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldLogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -158,9 +158,9 @@ public class CrashTest extends BaseNd4jTest {
// logisoftmax, softmax & softmax derivative
Nd4j.getExecutioner().exec(new OldSoftMax(x));
Nd4j.getExecutioner().exec(new SoftMaxDerivative(x));
Nd4j.getExecutioner().exec(new OldLogSoftMax(x));
Nd4j.getExecutioner().exec((CustomOp) new SoftMax(x));
Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(x));
Nd4j.getExecutioner().exec((CustomOp) new LogSoftMax(x));
// BooleanIndexing
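The legacy OldSoftMax/OldLogSoftMax activations used by these tests are replaced by the custom-op variants throughout this change; a minimal sketch of the new calling convention (input array and axis are illustrative):

    // before: INDArray out = Nd4j.getExecutioner().exec(new OldSoftMax(x));
    INDArray x = Nd4j.rand(DataType.FLOAT, 4, 5);
    // the (in, out, dimension) constructor and the CustomOp cast match the pattern above;
    // the result comes back in the returned array rather than through op.z()
    INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(x, x.dup(), -1))[0];

The same output is also reachable through outputArguments()[0] on the op, which is what the executioner tests further below use.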


@ -30,12 +30,13 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.bool.IsInf;
import org.nd4j.linalg.api.ops.impl.reduce.bool.IsNaN;
import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -334,7 +335,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
public void testTypesValidation_3() {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);
val result = Nd4j.getExecutioner().exec(new OldSoftMax(arrayX));
val result = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(arrayX, arrayX, -1));
}
public void testTypesValidation_4() {


@ -25,7 +25,9 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
@ -217,8 +219,8 @@ public class DerivativeTests extends BaseNd4jTest {
}
}
INDArray sm = Nd4j.getExecutioner().exec(new OldSoftMax(z.dup()));
INDArray zPrime = Nd4j.getExecutioner().exec(new SoftMaxDerivative(z));
INDArray sm = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(z.dup()))[0];
INDArray zPrime = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(z))[0];
System.out.println(Arrays.toString(sm.data().asDouble()));
System.out.println(Arrays.toString(zPrime.data().asDouble()));
assertNotEquals(sm, zPrime);
@ -396,7 +398,7 @@ public class DerivativeTests extends BaseNd4jTest {
//random array representing preout
INDArray X = Nd4j.rand(1, 2);
//preout transformed to y_hat with softmax
INDArray YHat = Nd4j.getExecutioner().exec(new OldSoftMax(X.dup()));
INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];
//hard coding something to construct a function with, using MSE
INDArray Y = Nd4j.create(new double[][] {{0.123, 1 - 0.123}});
@ -404,7 +406,7 @@ public class DerivativeTests extends BaseNd4jTest {
//This is the MSE now
double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();
INDArray softmaxDer = Nd4j.getExecutioner().exec(new SoftMaxDerivative(X.dup()));
INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];
//the way we apply the chain rule now is 2*(y-yhat)*softmaxder
INDArray dLdY = Y.sub(YHat).mul(-2);
@ -444,13 +446,13 @@ public class DerivativeTests extends BaseNd4jTest {
double x = X.getDouble(0, i);
Xiplus = X.dup();
Xiplus.put(0, i, x + epsilon);
YHatplus = Nd4j.getExecutioner().exec(new OldSoftMax(Xiplus.dup()));
YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0];
lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();
// -epsilon
Ximinus = X.dup();
Ximinus.put(0, i, x - epsilon);
YHatminus = Nd4j.getExecutioner().exec(new OldSoftMax(Ximinus.dup()));
YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0];
lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();
double gradienti = (lossplus - lossminus) / (2 * epsilon);
@ -479,16 +481,16 @@ public class DerivativeTests extends BaseNd4jTest {
INDArray X = Nd4j.rand(1, someLength);
//preout transformed to y_hat with softmax
INDArray YHat = Nd4j.getExecutioner().exec(new OldSoftMax(X.dup()));
INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];
//hard coding something to construct a function with, using MSE
INDArray temp = Nd4j.rand(1, someLength);
INDArray Y = Nd4j.getExecutioner().exec(new OldSoftMax(temp));
INDArray Y = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(temp))[0];
//This is the MSE now
double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();
INDArray softmaxDer = Nd4j.getExecutioner().exec(new SoftMaxDerivative(X.dup()));
INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];
//the way we apply the chain rule now is 2*(y-yhat)*softmaxder
INDArray dLdY = Y.sub(YHat).mul(-2);
@ -511,13 +513,13 @@ public class DerivativeTests extends BaseNd4jTest {
double x = X.getDouble(0, i);
Xiplus = X.dup();
Xiplus.put(0, i, x + epsilon);
YHatplus = Nd4j.getExecutioner().exec(new OldSoftMax(Xiplus.dup()));
YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0];
lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();
// -epsilon
Ximinus = X.dup();
Ximinus.put(0, i, x - epsilon);
YHatminus = Nd4j.getExecutioner().exec(new OldSoftMax(Ximinus.dup()));
YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0];
lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();
double gradienti = (lossplus - lossminus) / (2 * epsilon);
@ -538,14 +540,14 @@ public class DerivativeTests extends BaseNd4jTest {
// this is only for X a row vector
// should return rank 2 matrix diagonal elements are pi*(1-pi)
//rest are -pi*pj
INDArray p = Nd4j.getExecutioner().exec(new OldSoftMax(X.dup()));
INDArray p = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];
INDArray pCol = p.dup().transpose();
INDArray pipj = pCol.mmul(p);
pipj.muli(-1);
//so now pipj is correct except for the diagonal elements
// which by the way is what our current softmax der gives us
INDArray diagp = Nd4j.getExecutioner().exec(new SoftMaxDerivative(X.dup()));
INDArray diagp = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];
//ugly for loop to correct diag elements


@ -24,6 +24,7 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
@ -42,11 +43,11 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;
import org.nd4j.linalg.api.ops.random.impl.DropOut;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
@ -304,11 +305,11 @@ public class OpExecutionerTests extends BaseNd4jTest {
@Test
public void testRowSoftmax() {
OpExecutioner opExecutioner = Nd4j.getExecutioner();
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
OldSoftMax softMax = new OldSoftMax(arr);
opExecutioner.exec(softMax);
assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1);
val opExecutioner = Nd4j.getExecutioner();
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax);
assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);
}
@ -373,7 +374,7 @@ public class OpExecutionerTests extends BaseNd4jTest {
public void testSoftmax() {
INDArray vec = Nd4j.linspace(1, 6, 6, DataType.DOUBLE);
INDArray matrix = vec.dup().reshape('f', 2, 3);
Nd4j.getExecutioner().exec(new OldSoftMax(matrix));
Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix));
INDArray matrixAssertion = Nd4j.create(
new double[] {0.015876241, 0.015876241, 0.11731043, 0.11731043, 0.86681336, 0.86681336},
new int[] {2, 3}, 'f');
@ -384,7 +385,7 @@ public class OpExecutionerTests extends BaseNd4jTest {
public void testOtherSoftmax() {
INDArray vec = Nd4j.linspace(1, 18, 18, DataType.DOUBLE);
INDArray matrix = vec.dup().reshape('f', 3, 6);
Nd4j.getExecutioner().exec(new OldSoftMax(matrix));
Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix));
INDArray assertion = Nd4j.create(new double[] {2.9067235E-7, 2.9067235E-7, 2.9067235E-7, 5.8383102E-6,
5.8383102E-6, 5.8383102E-6, 1.1726559E-4, 1.1726559E-4, 1.1726559E-4, 0.0023553425,
0.0023553425, 0.0023553425, 0.047308315, 0.047308315, 0.047308315, 0.95021296, 0.95021296,
@ -517,9 +518,9 @@ public class OpExecutionerTests extends BaseNd4jTest {
0.3049033, 0.29277474, 0.29136384, 0.30316526, 0.2807459}, new int[] {150, 3}, 'f');
System.out.println("Data:" + input.data().length());
OldSoftMax softMax = new OldSoftMax(input);
Nd4j.getExecutioner().exec(softMax);
assertEquals(assertion, softMax.z());
val softMax = new SoftMax(input);
Nd4j.getExecutioner().exec((CustomOp) softMax);
assertEquals(assertion, softMax.outputArguments()[0]);
}
@ -557,9 +558,9 @@ public class OpExecutionerTests extends BaseNd4jTest {
public void testSoftMax() {
OpExecutioner opExecutioner = Nd4j.getExecutioner();
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
OldSoftMax softMax = new OldSoftMax(arr);
opExecutioner.exec(softMax);
assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1);
val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax);
assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);
}
@Test


@ -29,6 +29,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
@ -50,12 +51,12 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
@ -96,9 +97,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
public void testSoftmaxReference() {
INDArray input = Nd4j.linspace(1,4,4, DataType.FLOAT).reshape(2,2);
INDArray dup = input.dup();
Nd4j.getExecutioner().exec(new OldSoftMax(dup));
Nd4j.getExecutioner().exec((CustomOp) new SoftMax(dup));
INDArray result = Nd4j.zeros(DataType.FLOAT, 2,2);
Nd4j.getExecutioner().exec(new OldSoftMax(input,result));
Nd4j.getExecutioner().exec((CustomOp) new SoftMax(input,result));
assertEquals(dup,result);
@ -322,9 +323,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
public void testRowSoftmax() {
OpExecutioner opExecutioner = Nd4j.getExecutioner();
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
OldSoftMax softMax = new OldSoftMax(arr);
opExecutioner.exec(softMax);
assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1);
val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax);
assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);
}
@Test
@ -422,23 +423,23 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
public void testSoftMax() {
OpExecutioner opExecutioner = Nd4j.getExecutioner();
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
OldSoftMax softMax = new OldSoftMax(arr);
opExecutioner.exec(softMax);
assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1);
val softMax = new SoftMax(arr);
opExecutioner.exec((CustomOp) softMax);
assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
OldSoftMax softmax = new OldSoftMax(linspace.dup());
Nd4j.getExecutioner().exec(softmax);
assertEquals(linspace.rows(), softmax.z().sumNumber().doubleValue(), 1e-1);
val softmax = new SoftMax(linspace.dup());
Nd4j.getExecutioner().exec((CustomOp) softmax);
assertEquals(linspace.rows(), softmax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);
}
@Test
public void testDimensionSoftMax() {
INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
OldSoftMax max = new OldSoftMax(linspace);
Nd4j.getExecutioner().exec(max);
linspace.assign(max.z());
val max = new SoftMax(linspace);
Nd4j.getExecutioner().exec((CustomOp) max);
linspace.assign(max.outputArguments()[0]);
assertEquals(getFailureMessage(), linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1);
}
@ -782,7 +783,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
public void testSoftmax() {
INDArray vec = Nd4j.linspace(1, 18, 18, DataType.DOUBLE);
INDArray matrix = vec.dup().reshape(3, 6);
Nd4j.getExecutioner().exec(new OldSoftMax(matrix));
Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix));
INDArray assertion = Nd4j.create(
new double[] {0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913,
0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913,


@ -381,30 +381,32 @@ public class OperationProfilerTests extends BaseNd4jTest {
INDArray x = Nd4j.create(1000, 1000).assign(1.0);
INDArray y = Nd4j.create(1000, 1000).assign(1.0);
for (int e = 0; e < 10000; e++) {
int iterations = 100;
for (int e = 0; e < iterations; e++) {
x.addi(y);
}
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC);
val nanosC = System.nanoTime();
for (int e = 0; e < 10000; e++) {
for (int e = 0; e < iterations; e++) {
x.addi(y);
}
val nanosD = System.nanoTime();
val avgB = (nanosD - nanosC) / 10000;
val avgB = (nanosD - nanosC) / iterations;
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);
val nanosA = System.nanoTime();
for (int e = 0; e < 10000; e++) {
for (int e = 0; e < iterations; e++) {
x.addi(y);
}
val nanosB = System.nanoTime();
val avgA = (nanosB - nanosA) / 10000;
val avgA = (nanosB - nanosA) / iterations;
log.info("A: {}; B: {}", avgA, avgB);


@ -1429,15 +1429,17 @@ public class RandomTests extends BaseNd4jTest {
@Test
public void testRngRepeatabilityUniform(){
val nexp = Nd4j.create(DataType.FLOAT, 10);
Nd4j.getRandom().setSeed(12345);
INDArray out1 = Nd4j.create(DataType.FLOAT, 10);
val out1 = Nd4j.create(DataType.FLOAT, 10);
Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out1, 0.0, 1.0));
Nd4j.getRandom().setSeed(12345);
INDArray out2 = Nd4j.create(DataType.FLOAT, 10);
val out2 = Nd4j.create(DataType.FLOAT, 10);
Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out2, 0.0, 1.0));
assertEquals(out1, out2);
assertNotEquals(nexp, out1);
}
@Test


@ -16,6 +16,7 @@
package org.nd4j.linalg.shape;
import lombok.val;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@ -48,18 +49,17 @@ public class ShapeBufferTests extends BaseNd4jTest {
@Test
public void testRank() {
int[] shape = {2, 4};
int[] stride = {1, 2};
IntBuffer buff = Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt();
int rank = 2;
assertEquals(rank, Shape.rank(buff));
long[] shape = {2, 4};
long[] stride = {1, 2};
val shapeInfoBuffer = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false);
val buff = shapeInfoBuffer.asNioLong();
assertEquals(2, Shape.rank(buff));
}
@Test
public void testArrCreationShape() {
INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
val arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
for (int i = 0; i < 2; i++)
assertEquals(2, arr.size(i));
int[] stride = ArrayUtil.calcStrides(new int[] {2, 2});
@ -70,12 +70,13 @@ public class ShapeBufferTests extends BaseNd4jTest {
@Test
public void testShape() {
int[] shape = {2, 4};
int[] stride = {1, 2};
IntBuffer buff = Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt();
IntBuffer shapeView = Shape.shapeOf(buff);
long[] shape = {2, 4};
long[] stride = {1, 2};
val shapeInfoBuffer = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false);
val buff = shapeInfoBuffer.asNioLong();
val shapeView = Shape.shapeOf(buff);
assertTrue(Shape.contentEquals(shape, shapeView));
IntBuffer strideView = Shape.stride(buff);
val strideView = Shape.stride(buff);
assertTrue(Shape.contentEquals(stride, strideView));
assertEquals('c', Shape.order(buff));
assertEquals(1, Shape.elementWiseStride(buff));
@ -86,9 +87,9 @@ public class ShapeBufferTests extends BaseNd4jTest {
@Test
public void testBuff() {
int[] shape = {1, 2};
int[] stride = {1, 2};
IntBuffer buff = Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt();
long[] shape = {1, 2};
long[] stride = {1, 2};
val buff = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false).asNioLong();
assertTrue(Shape.isVector(buff));
}


@ -2661,4 +2661,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
this.indexer = null;
this.pointer = null;
}
@Override
public long platformAddress() {
return address();
}
}


@ -84,6 +84,14 @@ public interface DataBuffer extends Serializable, AutoCloseable {
*/
long address();
/**
* Returns the address of the platform-specific pointer:
* - for the native backend this is the host pointer
* - for the CUDA backend this is the device pointer
* @return address of the platform-specific pointer
*/
long platformAddress();
/**
* Returns true if the underlying data source
* is the same for both buffers (referential equals)
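A brief usage sketch of the new accessor documented above, assuming a plain CPU-backed buffer; names and shapes are illustrative only:

    INDArray arr = Nd4j.create(DataType.FLOAT, 10);
    DataBuffer buffer = arr.data();
    long addr = buffer.address();                  // existing accessor
    long platformAddr = buffer.platformAddress();  // host pointer on the native backend,
                                                   // device pointer on CUDA

The BaseDataBuffer override added above simply delegates to address(), so backends that keep that default return the same value for both calls.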