[WIP] CUDA tweaks (#60)

* special cpu concat
* special concat fix
* OpProfiler tweak for absent host pointers
* minor test tweak to see orders
* CUDA broadcasting diff orders fix
* faster iterations
* OldSoftMax/OldLogSoftMax gone
* RandomLauncher tweaks
* additional check in random tests
* skip prepare/register action for empty arrays
* npz float16 fix
* empty reduction cuda fixes
* ShapeBufferTests tweaks

Signed-off-by: raver119 <raver119@gmail.com>

Branch: master
parent 6ce458e949
commit 9cf28ea6c9
@@ -21,26 +21,27 @@
 #include <NDArray.h>
 #include <helpers/helper_random.h>
 #include <graph/RandomGenerator.h>
+#include <execution/LaunchContext.h>

 namespace nd4j {
     class RandomLauncher {
     public:
-        static void applyDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr);
+        static void applyDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr);
-        static void applyInvertedDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr);
+        static void applyInvertedDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z = nullptr);
-        static void applyAlphaDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z = nullptr);
+        static void applyAlphaDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z = nullptr);

-        static void fillUniform(nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to);
+        static void fillUniform(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to);

-        static void fillGaussian(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);
+        static void fillGaussian(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);

-        static void fillExponential(nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda);
+        static void fillExponential(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda);

-        static void fillLogNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);
+        static void fillLogNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);

-        static void fillTruncatedNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);
+        static void fillTruncatedNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev);

-        static void fillBinomial(nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob);
+        static void fillBinomial(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob);

-        static void fillBernoulli(nd4j::graph::RandomGenerator& rng, NDArray* array, double prob);
+        static void fillBernoulli(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double prob);
     };
 }
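For orientation, a minimal sketch (not part of the commit) of how a call site adapts to the signatures above; the function name fillExample and the way the context and generator are obtained are assumptions made for illustration only:

// Hypothetical caller, showing only the signature change: the LaunchContext
// that used to be implicit is now passed explicitly, so the CUDA backend can
// stage pointers and synchronize after the launch.
#include <helpers/RandomLauncher.h>
#include <execution/LaunchContext.h>

void fillExample(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator &rng, nd4j::NDArray *z) {
    // before this commit: RandomLauncher::fillUniform(rng, z, 0.0, 1.0);
    nd4j::RandomLauncher::fillUniform(context, rng, z, 0.0, 1.0);
}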
@@ -23,76 +23,97 @@
 #include <helpers/RandomLauncher.h>
 #include <graph/RandomGenerator.h>
 #include <ops/declarable/CustomOperations.h>
+#include <helpers/PointersManager.h>

 namespace nd4j {
     // FIXME: implement this

-    void RandomLauncher::applyDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) {
+    void RandomLauncher::applyDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) {
         if (z == nullptr)
             z = array;

         ExtraArguments arguments({retainProb});
+        PointersManager pm(context, "applyDropOut");

-        NativeOpExecutioner::execRandom(nullptr, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
+        NativeOpExecutioner::execRandom(context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
+        pm.synchronize();
     }

-    void RandomLauncher::applyInvertedDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) {
+    void RandomLauncher::applyInvertedDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) {
         if (z == nullptr)
             z = array;

         ExtraArguments arguments({retainProb});
+        PointersManager pm(context, "applyInvertedDropOut");

-        NativeOpExecutioner::execRandom(nullptr, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
+        NativeOpExecutioner::execRandom(context, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
+        pm.synchronize();
     }

-    void RandomLauncher::applyAlphaDropOut(nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) {
+    void RandomLauncher::applyAlphaDropOut(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) {
         if (z == nullptr)
             z = array;

         ExtraArguments arguments({retainProb, alpha, beta, alphaPrime});
+        PointersManager pm(context, "applyAlphaDropOut");

-        NativeOpExecutioner::execRandom(nullptr, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
+        NativeOpExecutioner::execRandom(context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType()));
+        pm.synchronize();
     }

-    void RandomLauncher::fillBernoulli(nd4j::graph::RandomGenerator& rng, NDArray* array, double prob) {
+    void RandomLauncher::fillBernoulli(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double prob) {
         ExtraArguments arguments({prob});
+        PointersManager pm(context, "fillBernoulli");

-        NativeOpExecutioner::execRandom(nullptr, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        NativeOpExecutioner::execRandom(context, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        pm.synchronize();
     }

-    void RandomLauncher::fillUniform(nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to) {
+    void RandomLauncher::fillUniform(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double from, double to) {
         ExtraArguments arguments({from, to});
+        PointersManager pm(context, "fillUniform");

-        NativeOpExecutioner::execRandom(nullptr, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        NativeOpExecutioner::execRandom(context, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        pm.synchronize();
     }

-    void RandomLauncher::fillGaussian(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
+    void RandomLauncher::fillGaussian(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
         ExtraArguments arguments({mean, stdev});
+        PointersManager pm(context, "fillGaussian");

-        NativeOpExecutioner::execRandom(nullptr, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        pm.synchronize();
     }

-    void RandomLauncher::fillExponential(nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda) {
+    void RandomLauncher::fillExponential(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double lambda) {
         ExtraArguments arguments({lambda});
+        PointersManager pm(context, "fillExponential");

-        NativeOpExecutioner::execRandom(nullptr, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        NativeOpExecutioner::execRandom(context, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        pm.synchronize();
     }

-    void RandomLauncher::fillLogNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
+    void RandomLauncher::fillLogNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
         ExtraArguments arguments({mean, stdev});
+        PointersManager pm(context, "fillLogNormal");

-        NativeOpExecutioner::execRandom(nullptr, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        pm.synchronize();
     }

-    void RandomLauncher::fillTruncatedNormal(nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
+    void RandomLauncher::fillTruncatedNormal(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) {
         ExtraArguments arguments({mean, stdev});
+        PointersManager pm(context, "fillTruncatedNormal");

-        NativeOpExecutioner::execRandom(nullptr, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        NativeOpExecutioner::execRandom(context, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        pm.synchronize();
     }

-    void RandomLauncher::fillBinomial(nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) {
+    void RandomLauncher::fillBinomial(nd4j::LaunchContext *context, nd4j::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) {
         ExtraArguments arguments({(double) trials, prob});
+        PointersManager pm(context, "fillBinomial");

-        NativeOpExecutioner::execRandom(nullptr, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        NativeOpExecutioner::execRandom(context, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType()));
+        pm.synchronize();
     }
 }
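Every method above now follows the same pattern; the sketch below restates it with placeholder names (someProbability, opCode), and the stated reason for synchronize() is an assumption rather than something the commit itself explains:

// Sketch of the launch pattern, not verbatim library code: bind a
// PointersManager to the LaunchContext, pass the context to execRandom, then
// block until the (possibly asynchronous CUDA) work finishes so host-side
// temporaries such as ExtraArguments remain valid for the whole launch.
ExtraArguments arguments({someProbability});      // op-specific scalar args (placeholder value)
PointersManager pm(context, "opName");            // "opName" is only a debug label
NativeOpExecutioner::execRandom(context, opCode, &rng,
        array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(),
        arguments.argumentsAsT(array->dataType()));
pm.synchronize();                                 // wait before arguments goes out of scope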
@@ -128,6 +128,10 @@ namespace functions {
         }
         __syncthreads();

+        auto xOrder = shape::order(xShapeInfo);
+        auto yOrder = shape::order(tadOnlyShapeInfo);
+        auto zOrder = shape::order(tadOnlyShapeInfoZ);
+
         for (int r = blockIdx.x; r < numTads; r += gridDim.x) {

@@ -135,7 +139,7 @@ namespace functions {
             auto rZ = z + tadOffsetsZ[r];

-            if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) {
+            if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1 && xOrder == yOrder && xOrder == zOrder) {
                 for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
                     rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]);
             }

@@ -190,6 +194,9 @@ namespace functions {
         }
         __syncthreads();

+        auto xOrder = shape::order(tadOnlyShapeInfo);
+        auto yOrder = shape::order(yShapeInfo);
+        auto zOrder = shape::order(tadOnlyShapeInfoZ);

         for (int r = blockIdx.x; r < numTads; r += gridDim.x) {

@@ -197,7 +204,7 @@ namespace functions {
             auto rZ = z + tadOffsetsZ[r];

-            if(tadEWS > 0 && zEWS > 0 && yEWS > 0) {
+            if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && xOrder == yOrder && xOrder == zOrder) {
                 for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
                     rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]);
             }
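The added xOrder/yOrder/zOrder guard restricts the element-wise-stride fast path to buffers that traverse elements in the same order; when orders differ, the kernel now falls through to the shape-aware path. A small self-contained illustration (not from the commit) of why a plain linear walk is not order-agnostic:

// The same linear buffer index maps to different logical (row, col) positions
// in a 'c' (row-major) and an 'f' (column-major) 2x3 array, so a fast path
// that walks two buffers with i * ews only matches elements when orders agree.
#include <cstdio>

int main() {
    const int rows = 2, cols = 3;
    for (int i = 0; i < rows * cols; i++) {
        int cRow = i / cols, cCol = i % cols;   // position if the buffer is 'c' ordered
        int fRow = i % rows, fCol = i / rows;   // position if the buffer is 'f' ordered
        std::printf("linear %d -> c:(%d,%d)  f:(%d,%d)\n", i, cRow, cCol, fRow, fCol);
    }
    return 0;
}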
@@ -44,7 +44,7 @@ namespace nd4j {
             auto z = OUTPUT_VARIABLE(0);
             auto f = T_ARG(0);

-            RandomLauncher::fillBernoulli(rng, z, f);
+            RandomLauncher::fillBernoulli(block.launchContext(), rng, z, f);

             return Status::OK();
         }

@@ -53,7 +53,7 @@ namespace nd4j {
             auto z = OUTPUT_VARIABLE(0);
             auto lambda = T_ARG(0);

-            RandomLauncher::fillExponential(rng, z, lambda);
+            RandomLauncher::fillExponential(block.launchContext(), rng, z, lambda);

             return Status::OK();
         }

@@ -39,7 +39,7 @@ namespace nd4j {
             functions::random::RandomFunction<T>::template execTransform<randomOps::GaussianDistribution<T>>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data());
             */

-            RandomLauncher::fillGaussian(rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));
+            RandomLauncher::fillGaussian(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));

             return Status::OK();
         }

@@ -53,7 +53,7 @@ namespace nd4j {
             */
             REQUIRE_TRUE(block.numT() > 1, 0, "RandomUniform: to/from must be set");

-            RandomLauncher::fillUniform(rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));
+            RandomLauncher::fillUniform(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));
             return Status::OK();
         }
@@ -1203,71 +1203,7 @@ static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& o
 //////////////////////////////////////////////////////////////////////////
 template<typename T>
 static void concat_(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
-    [the previous in-place implementation, about 60 lines, was removed here; it is the same code that is added below as SpecialMethods<T>::concatCpuGeneric(const std::vector<NDArray*>&, NDArray&, const int)]
+    nd4j::SpecialMethods<T>::concatCpuGeneric(inArrs, output, axis);
 }

 void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
@@ -81,7 +81,7 @@ namespace nd4j {
             auto z = OUTPUT_VARIABLE(0); //NDArrayFactory::create_<T>('c', shape, block.getWorkspace());

-            RandomLauncher::fillUniform(block.randomGenerator(), z, from, to);
+            RandomLauncher::fillUniform(block.launchContext(), block.randomGenerator(), z, from, to);

             // FIXME:
             //OVERWRITE_RESULT(z);

@@ -105,7 +105,7 @@ namespace nd4j {
             if (!block.isInplace())
                 z->assign(input);

-            RandomLauncher::applyDropOut(block.randomGenerator(), z, prob);
+            RandomLauncher::applyDropOut(block.launchContext(), block.randomGenerator(), z, prob);
         }
         break;
         case nd4j::random::DropOutInverted: {

@@ -140,7 +140,7 @@ namespace nd4j {
             auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_<T>('c', shape, block.getWorkspace());

-            RandomLauncher::fillGaussian(block.randomGenerator(), z, mean, stdev);
+            RandomLauncher::fillGaussian(block.launchContext(), block.randomGenerator(), z, mean, stdev);

             // FIXME: !!
             //OVERWRITE_RESULT(z);

@@ -168,7 +168,7 @@ namespace nd4j {
             auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_<T>('c', shape, block.getWorkspace());

-            RandomLauncher::fillBernoulli(block.randomGenerator(), z, prob);
+            RandomLauncher::fillBernoulli(block.launchContext(), block.randomGenerator(), z, prob);

             // FIXME:
             //OVERWRITE_RESULT(z);

@@ -201,7 +201,7 @@ namespace nd4j {
             auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_<T>('c', shape, block.getWorkspace());

-            RandomLauncher::fillBinomial(block.randomGenerator(), z, trials, prob);
+            RandomLauncher::fillBinomial(block.launchContext(), block.randomGenerator(), z, trials, prob);

             // FIXME: !!!
             //OVERWRITE_RESULT(z);

@@ -233,7 +233,7 @@ namespace nd4j {
             auto z = OUTPUT_VARIABLE(0);//NDArrayFactory::create_<T>('c', shape, block.getWorkspace());

-            RandomLauncher::fillLogNormal(block.randomGenerator(), z, mean, stdev);
+            RandomLauncher::fillLogNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev);

             // FIXME: !!
             //OVERWRITE_RESULT(z);

@@ -265,7 +265,7 @@ namespace nd4j {
             auto z = OUTPUT_VARIABLE(0); // NDArrayFactory::create_<T>('c', shape, block.getWorkspace());

-            RandomLauncher::fillTruncatedNormal(block.randomGenerator(), z, mean, stdev);
+            RandomLauncher::fillTruncatedNormal(block.launchContext(), block.randomGenerator(), z, mean, stdev);

             // FIXME: !!!
             //OVERWRITE_RESULT(z);

@@ -301,7 +301,7 @@ namespace nd4j {
             if (!block.isInplace())
                 z->assign(input);

-            RandomLauncher::applyAlphaDropOut(block.randomGenerator(), z, prob, a, b, pa);
+            RandomLauncher::applyAlphaDropOut(block.launchContext(), block.randomGenerator(), z, prob, a, b, pa);
         }
         break;
         case nd4j::random::Linspace: {
@@ -28,9 +28,81 @@
 #include <NDArray.h>
 #include <ops/declarable/CustomOperations.h>
 #include <types/types.h>
+#include <helpers/Loops.h>

 namespace nd4j {

+    /**
+     * Concatenate multiple arrays of the same shape together
+     * along a particular dimension
+     */
+    template <typename T>
+    void SpecialMethods<T>::concatCpuGeneric(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
+        const uint numOfArrs = inArrs.size();
+
+        int outDim;
+        const bool isOutputVector = output.isCommonVector(outDim);
+
+        if(isOutputVector || (axis == 0 && output.ordering() == 'c')) {
+
+            bool allVectorsOrScalars = true;
+            const uint outEws = isOutputVector ? output.stridesOf()[outDim] : output.ews();
+
+            std::vector<int> nonUnityDim(numOfArrs);
+            std::vector<Nd4jLong> zOffset(numOfArrs);
+
+            for(int i = 0; i < numOfArrs; i++) {
+                allVectorsOrScalars &= (inArrs[i]->lengthOf() == 1 || inArrs[i]->isCommonVector(nonUnityDim[i]));
+                if(!allVectorsOrScalars)
+                    break;
+                if(i == 0) zOffset[0] = 0;
+                else       zOffset[i] = zOffset[i - 1] + outEws * inArrs[i - 1]->lengthOf();
+            }
+
+            if(allVectorsOrScalars) {
+
+                T* outBuff = output.bufferAsT<T>();
+
+                PRAGMA_OMP_PARALLEL_FOR_SIMD
+                for (uint r = 0; r < numOfArrs; r++) {
+
+                    const uint arrLen = inArrs[r]->lengthOf();
+                    const uint xEws = (arrLen == 1) ? 1 : inArrs[r]->stridesOf()[nonUnityDim[r]];
+
+                    T *z = outBuff + zOffset[r];
+                    T *x = inArrs[r]->bufferAsT<T>();
+
+                    if(outEws == 1 && xEws == 1)
+                        for (uint e = 0; e < arrLen; e++)
+                            z[e] = x[e];
+                    else
+                        for (uint e = 0; e < arrLen; e++)
+                            z[e * outEws] = x[e * xEws];
+                }
+                return;
+            }
+        }
+
+        const int rank  = inArrs[0]->rankOf();
+        const int rank2 = 2*rank;
+        std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
+
+        // take into account indices for first array
+        indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
+
+        // loop through the rest of input arrays
+        for(int i = 1; i < numOfArrs; ++i) {
+            indices[i][2 * axis]     = indices[i-1][2 * axis + 1];                           // index start from
+            indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding)
+        }
+
+        PRAGMA_OMP_PARALLEL_FOR_SIMD
+        for(int i = 0; i < numOfArrs; ++i) {
+            auto temp = output(indices[i], true);
+            nd4j::TransformLoops<T,T,T>::template loopTransform<simdOps::Assign<T,T>, false>(inArrs[i]->bufferAsT<T>(), inArrs[i]->getShapeInfo(), temp.bufferAsT<T>(), temp.getShapeInfo(), nullptr);
+        }
+    }
+
     /**
      * Concatenate multiple arrays of the same shape together
      * along a particular dimension
@@ -38,24 +110,14 @@ namespace nd4j {
     template <typename T>
     void SpecialMethods<T>::concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *vresult, Nd4jLong *resultShapeInfo) {
         auto result = reinterpret_cast<T *>(vresult);
-
-        std::vector<Nd4jLong> iArgs = {dimension};
-        std::vector<double> tArgs;
-        std::vector<bool> bArgsEmpty;
         std::vector<NDArray*> inputs(numArrays);
-        std::vector<NDArray*> outputs(1);
-
-        outputs[0] = new NDArray(static_cast<void*>(result), static_cast<Nd4jLong*>(resultShapeInfo));
+        NDArray output(static_cast<void*>(result), static_cast<Nd4jLong*>(resultShapeInfo));

         for(int i = 0; i < numArrays; ++i)
             inputs[i] = new NDArray(static_cast<void *>(data[i]), static_cast<Nd4jLong*>(inputShapeInfo[i]));

-        nd4j::ops::concat op;
-        auto status = op.execute(inputs, outputs, tArgs, iArgs, bArgsEmpty);
-        if(status != Status::OK())
-            throw std::runtime_error("concatCpuGeneric fails to be executed !");
-
-        delete outputs[0];
+        nd4j::SpecialMethods<T>::concatCpuGeneric(inputs, output, dimension);

         for(int i = 0; i < numArrays; ++i)
             delete inputs[i];
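The vector fast path moved into concatCpuGeneric computes each input's starting offset in the output buffer as a running sum of the preceding lengths. A standalone sketch of that bookkeeping (hypothetical lengths, plain std::vector instead of NDArray):

// zOffset[i] = zOffset[i-1] + outEws * length(i-1): input i is written right
// after input i-1 in the flat output, scaled by the output element-wise stride.
#include <cstdio>
#include <vector>

int main() {
    std::vector<long> lengths = {4, 1, 3};   // hypothetical input lengths
    const long outEws = 1;                   // output element-wise stride
    std::vector<long> zOffset(lengths.size(), 0);
    for (size_t i = 1; i < lengths.size(); i++)
        zOffset[i] = zOffset[i - 1] + outEws * lengths[i - 1];
    for (size_t i = 0; i < lengths.size(); i++)
        std::printf("input %zu starts at output offset %ld\n", i, zOffset[i]);
    return 0;
}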
@@ -30,6 +30,8 @@
 #include <pointercast.h>

 namespace nd4j {
+    class NDArray;

     //FIXME: get rid of this redefinition
     typedef union
     {

@@ -47,6 +49,7 @@ namespace nd4j {
     template <typename T>
     class ND4J_EXPORT SpecialMethods {
     public:
+        static void concatCpuGeneric(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis);
         static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *result, Nd4jLong *resultShapeInfo);
         static void accumulateGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length);
         static void averageGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length, bool propagate);
@@ -541,8 +541,6 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.transforms.strict.Log.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid.class,
-            org.nd4j.linalg.api.ops.impl.transforms.strict.OldLogSoftMax.class,
-            org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative.class,
             org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh.class,
@@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.Op;
 import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
 import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity;
 import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
 import org.nd4j.linalg.api.ops.impl.scalar.Step;

@@ -161,109 +162,4 @@ public enum Activation {
                 throw new UnsupportedOperationException("Activation function not yet supported: " + this);
         }
     }
-
-    /**
-     * Get the Activation function as an ND4J Transform, applied on either the input or a copy of the input
-     *
-     * @param in Input to apply the activation function op to
-     * @param dup If true: duplicate the array before applying the transform. If false: don't duplicate
-     * @return The transform op (execute using {@code Nd4j.getExecutioner().exec(op)}
-     */
-    public Op asTransform(INDArray in, boolean dup) {
-        if (dup) {
-            in = in.dup();
-        }
-        switch (this) {
-            case CUBE:
-                return new Cube(in);
-            case ELU:
-                return new ELU(in);
-            case HARDSIGMOID:
-                return new HardSigmoid(in);
-            case HARDTANH:
-                return new HardTanh(in);
-            case IDENTITY:
-                return new OldIdentity(in);
-            case LEAKYRELU:
-                return new LeakyReLU(in);
-            case RATIONALTANH:
-                return new RationalTanh(in);
-            case RELU:
-                return new RectifiedLinear(in);
-            case SIGMOID:
-                return new Sigmoid(in);
-            case SOFTMAX:
-                return new OldSoftMax(in);
-            case SOFTPLUS:
-                return new SoftPlus(in);
-            case SOFTSIGN:
-                return new SoftSign(in);
-            case TANH:
-                return new Tanh(in);
-            case RECTIFIEDTANH:
-                return new RectifiedTanh(in);
-            case SELU:
-                return new SELU(in);
-            case SWISH:
-                return new Swish(in);
-            case GELU:
-                return new GELU(in);
-            case RRELU:
-            default:
-                throw new UnsupportedOperationException("Not supported via this method: " + this);
-        }
-    }
-
-    /**
-     * Get the Activation function <i>derivative</i> (i.e., dOut/dIn) as an ND4J Transform, applied on either the input
-     * or a copy of the input
-     *
-     * @param in Input to apply the activation function derivative op to
-     * @param dup If true: duplicate the array before applying the transform. If false: don't duplicate
-     * @return The op (execute using {@code Nd4j.getExecutioner().exec(op)}
-     */
-    public Op asTransformDerivative(INDArray in, boolean dup) {
-        if (dup) {
-            in = in.dup();
-        }
-        switch (this) {
-            case CUBE:
-                return new CubeDerivative(in);
-            case ELU:
-                return new ELUDerivative(in);
-            case HARDSIGMOID:
-                return new HardSigmoidDerivative(in);
-            case HARDTANH:
-                return new HardTanhDerivative(in);
-            case LEAKYRELU:
-                return new LeakyReLUDerivative(in);
-            case RATIONALTANH:
-                return new RationalTanhDerivative(in);
-            case SIGMOID:
-                return new SigmoidDerivative(in);
-            case SOFTPLUS:
-                return new Sigmoid(in);
-            case SOFTSIGN:
-                return new SoftSignDerivative(in);
-            case TANH:
-                return new TanhDerivative(in);
-            case RECTIFIEDTANH:
-                return new RectifiedTanhDerivative(in);
-            case SELU:
-                return new SELUDerivative(in);
-            case SWISH:
-                return new SwishDerivative(in);
-            case SOFTMAX:
-                return new SoftMaxDerivative(in);
-            case IDENTITY:
-                return new ScalarSet(in, 1.0);
-            case RELU:
-                return new Step(in);
-            case GELU:
-                return new GELUDerivative(in);
-            case RRELU:
-            default:
-                throw new UnsupportedOperationException("Not supported via this method: " + this);
-        }
-    }
 }
@@ -20,7 +20,8 @@ import lombok.EqualsAndHashCode;
 import lombok.Getter;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
+import org.nd4j.linalg.api.ops.CustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;

@@ -34,14 +35,14 @@ public class ActivationSoftmax extends BaseActivationFunction {
     @Override
     public INDArray getActivation(INDArray in, boolean training) {
-        Nd4j.getExecutioner().execAndReturn(new OldSoftMax(in));
+        Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(in, in));
         return in;
     }

     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray out = Nd4j.getExecutioner().exec(new OldSoftMax(in));
+        INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0];
         INDArray x = out.mul(epsilon).sum(1);
         INDArray dLdz = out.mul(epsilon.subColumnVector(x));
         return new Pair<>(dLdz, null);
@@ -21,14 +21,9 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
 import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.ops.transforms.Transforms;
-
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
@@ -19,10 +19,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.CustomOp;
+import org.nd4j.linalg.api.ops.Op;
 import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;

+import java.nio.Buffer;
 import java.util.Collections;
 import java.util.List;

@@ -53,10 +57,6 @@ public class SoftMax extends BaseDynamicTransformOp {
         super(sameDiff, args, inPlace);
     }

-    public SoftMax(INDArray input, INDArray result){
-        super(new INDArray[]{input}, new INDArray[]{result});
-    }
-
     public SoftMax(SameDiff sameDiff, SDVariable[] args, int dimension) {
         super(sameDiff, args, false);
         this.dimension = dimension;

@@ -75,13 +75,19 @@ public class SoftMax extends BaseDynamicTransformOp {
         addIArgument(dimension);
     }

+    public SoftMax(INDArray input){
+        this(input, input);
+    }
+
+    public SoftMax(INDArray input, INDArray result){
+        this(input, result, -1);
+    }
+
     @Override
     public String opName() {
         return "softmax";
     }

     @Override
     public String onnxName() {
         return "Softmax";
@@ -1,87 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- * [standard Apache License, Version 2.0 header]
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.strict;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.base.Preconditions;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Old LogSoftMax function
- *
- * @author Adam Gibson
- */
-public class OldLogSoftMax extends BaseTransformStrictOp {
-    public OldLogSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldLogSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldLogSoftMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldLogSoftMax() {
-    }
-
-    public OldLogSoftMax(INDArray x){
-        this(x,x);
-    }
-
-    public OldLogSoftMax(INDArray x, INDArray z) {
-        super(x, z);
-        Preconditions.checkArgument(x != null && x.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got x (source) array with shape: %ndShape", x);
-        Preconditions.checkArgument(z != null && z.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got z (result) array with shape: %ndShape", z);
-    }
-
-    @Override
-    public int opNum() {
-        return 2;
-    }
-
-    @Override
-    public String opName() {
-        return "old_logsoftmax";
-    }
-
-    @Override
-    public String onnxName() {
-        return "old_LogSoftmax";
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0));
-        return Collections.singletonList(ret);
-    }
-}
@@ -1,94 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- * [standard Apache License, Version 2.0 header]
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.strict;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.base.Preconditions;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Soft max function
- * row_maxes is a row vector (max for each row)
- * row_maxes = rowmaxes(input)
- * diff = exp(input - max) / diff.rowSums()
- * Outputs a probability distribution.
- * Note that this is a parameterized model and requires
- * the sum and max for the vector being calculated
- *
- * @author Adam Gibson
- */
-public class OldSoftMax extends BaseTransformStrictOp {
-    public OldSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
-    }
-
-    public OldSoftMax(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
-    }
-
-    public OldSoftMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public OldSoftMax() {
-    }
-
-    public OldSoftMax(INDArray x){
-        this(x,x);
-    }
-
-    public OldSoftMax(INDArray x, INDArray z) {
-        super(x, z);
-        Preconditions.checkArgument(x != null && x.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got x (source) array with shape: %ndShape", x);
-        Preconditions.checkArgument(z != null && z.rank() == 2, "OldSoftMax op supports rank 2 (2d) arrays only. Got z (result) array with shape: %ndShape", z);
-    }
-
-    @Override
-    public int opNum() {
-        return 0;
-    }
-
-    @Override
-    public String opName() {
-        return "old_softmax";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().softmaxDerivative(arg(), i_v.get(0), 1);
-        return Collections.singletonList(ret);
-    }
-}
@@ -19,20 +19,27 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
+import org.nd4j.linalg.api.buffer.DataBuffer;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.CustomOp;
+import org.nd4j.linalg.api.ops.Op;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
+
+import java.nio.Buffer;

 /**
  * Softmax derivative
  *
  * @author Adam Gibson
  */
-public class SoftMaxDerivative extends OldSoftMax {
+public class SoftMaxDerivative extends SoftMax {
     public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
-        super(sameDiff, i_v1, i_v2);
+        super(sameDiff, new SDVariable[]{i_v1, i_v2});
     }

     public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
-        super(sameDiff, i_v1, i_v2, inPlace);
+        super(sameDiff, new SDVariable[]{ i_v1, i_v2}, inPlace);
     }

     public SoftMaxDerivative(INDArray x, INDArray z) {

@@ -40,11 +47,13 @@ public class SoftMaxDerivative extends OldSoftMax {
     }

     public SoftMaxDerivative(INDArray x) {
-        super(x);
+        super(x, x);
     }

     public SoftMaxDerivative() {}

     @Override
     public int opNum() {
         return 1;
@ -1334,6 +1334,17 @@ public class Shape {
        }
    }

+    public static boolean isVector(LongBuffer shapeInfo) {
+        int rank = Shape.rank(shapeInfo);
+        if (rank > 2 || rank < 1)
+            return false;
+        else {
+            long len = Shape.length(shapeInfo);
+            val shape = Shape.shapeOf(shapeInfo);
+            return shape.get(0) == len || shape.get(1) == len;
+        }
+    }
+
    /**
     * Returns whether the given shape is a vector
     *

@ -2498,6 +2509,14 @@ public class Shape {
        return ret;
    }

+    public static int length(LongBuffer buffer) {
+        int ret = 1;
+        val shape = Shape.shapeOf(buffer);
+        int rank = Shape.rank(buffer);
+        for (int i = 0; i < rank; i++)
+            ret *= shape.get(i);
+        return ret;
+    }
+
    /**
     * Gets the rank given the shape info buffer

@ -2764,8 +2783,8 @@ public class Shape {
     * @param rank the rank to get the length for
     * @return rank * 2 + 4
     */
-    public static int shapeInfoLength(int rank) {
-        return rank * 2 + 4;
+    public static int shapeInfoLength(long rank) {
+        return (int) rank * 2 + 4;
    }

    public static int shapeInfoLength(long[] shape) {

@ -3072,6 +3091,11 @@ public class Shape {
        return buffer.get(length2 - 2);
    }

+    public static long elementWiseStride(LongBuffer buffer) {
+        int length2 = shapeInfoLength(buffer.get(0));
+        return buffer.get(length2 - 2);
+    }
+
    /**
     * Get the element wise stride for the
     * shape info buffer

@ -3179,40 +3203,6 @@ public class Shape {
        throw new RuntimeException("setOrder called");
    }

-    /**
-     * Creates the shape information buffer
-     * given the shape,stride
-     * @param shape the shape for the buffer
-     * @param stride the stride for the buffer
-     * @param offset the offset for the buffer
-     * @param elementWiseStride the element wise stride for the buffer
-     * @param order the order for the buffer
-     * @return the shape information buffer given the parameters
-     */
-    public static DataBuffer createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
-        if (shape.length != stride.length)
-            throw new IllegalStateException("Shape and stride must be the same length");
-
-        int rank = shape.length;
-        int shapeBuffer[] = new int[rank * 2 + 4];
-        shapeBuffer[0] = rank;
-        int count = 1;
-        for (int e = 0; e < shape.length; e++)
-            shapeBuffer[count++] = shape[e];
-
-        for (int e = 0; e < stride.length; e++)
-            shapeBuffer[count++] = stride[e];
-
-        shapeBuffer[count++] = (int) offset;
-        shapeBuffer[count++] = elementWiseStride;
-        shapeBuffer[count] = (int) order;
-
-        DataBuffer ret = Nd4j.createBufferDetached(shapeBuffer);
-        ret.setConstant(true);
-
-        return ret;
-    }
-
    public static DataBuffer createShapeInformation(long[] shape, long[] stride, long elementWiseStride, char order, DataType dataType, boolean empty) {
        boolean isEmpty = empty;
        if (!empty)

@ -3438,9 +3428,20 @@ public class Shape {

    public static boolean contentEquals(long[] arr, IntBuffer other) {
        for (int i = 0; i < arr.length; i++) {
-            Buffer buffer2 = (Buffer) other;
-            buffer2.position(i);
-            if (arr[i] != other.get()) {
+            val t = arr[i];
+            val o = other.get(i);
+            if (t != o) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    public static boolean contentEquals(long[] arr, LongBuffer other) {
+        for (int i = 0; i < arr.length; i++) {
+            val t = arr[i];
+            val o = other.get(i);
+            if (t != o) {
                return false;
            }
        }

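Reviewer note (not part of the diff): a usage sketch for the LongBuffer-based Shape helpers added above (isVector, length, elementWiseStride, and the long-rank shapeInfoLength overload they rely on). Building the LongBuffer via shapeInfoDataBuffer().asLong() is an assumption for illustration, not something this PR does.

    import java.nio.LongBuffer;

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.shape.Shape;
    import org.nd4j.linalg.factory.Nd4j;

    public class LongBufferShapeSketch {
        public static void main(String[] args) {
            INDArray row = Nd4j.createUninitialized(DataType.FLOAT, 1, 5);

            // assumption: copy the shape-info contents into a java.nio.LongBuffer
            LongBuffer shapeInfo = LongBuffer.wrap(row.shapeInfoDataBuffer().asLong());

            boolean vector = Shape.isVector(shapeInfo);    // true: one of the two dims equals the length
            int length = Shape.length(shapeInfo);          // 5
            long ews = Shape.elementWiseStride(shapeInfo); // read from the ews slot (index rank * 2 + 2)

            System.out.println(vector + " " + length + " " + ews);
        }
    }
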
@ -25,8 +25,8 @@ import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;

@ -121,7 +121,7 @@ public class LossBinaryXENT implements ILossFunction {
        INDArray scoreArr;
        if (activationFn instanceof ActivationSoftmax) {
            //TODO Post GPU support for custom ops: Use LogSoftMax op to avoid numerical issues when calculating score
-            INDArray logsoftmax = Nd4j.getExecutioner().exec(new OldSoftMax(preOutput.dup()));
+            INDArray logsoftmax = Nd4j.exec((CustomOp) new SoftMax(preOutput, preOutput.ulike(), -1))[0];
            Transforms.log(logsoftmax, false);
            scoreArr = logsoftmax.muli(labels);

@ -20,7 +20,8 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
+import org.nd4j.linalg.api.ops.CustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

@ -139,7 +140,7 @@ public class LossMixtureDensity implements ILossFunction {

        // Alpha is a softmax because
        // the alpha should all sum to 1 for a given gaussian mixture.
-        mdc.alpha = Nd4j.getExecutioner().exec(new OldSoftMax(mdc.alpha));
+        mdc.alpha = Nd4j.exec((CustomOp) new SoftMax(mdc.alpha, mdc.alpha, -1))[0];

        // Mu comes directly from the network as an unmolested value.
        // Note that this effectively means that the output layer of

@ -21,6 +21,7 @@ import lombok.val;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.reduce3.*;

@ -29,6 +30,7 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNot;
import org.nd4j.linalg.api.ops.impl.shape.Cross;
import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;

@ -512,7 +514,7 @@ public class Transforms {
     * @return
     */
    public static INDArray softmax(INDArray in, boolean copy) {
-        return Nd4j.getExecutioner().exec(new OldSoftMax(in, (copy ? in.ulike() : in)));
+        return Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, (copy ? in.ulike() : in), -1))[0];
    }

    /**

@ -240,7 +240,7 @@ public class OpProfiler {
        String opClass = getOpClass(op);
        classCounter.incrementCount(opClass);

-        if(op.x() == null || (op.x() != null && op.x().data().address() == lastZ && op.z() == op.x() && op.y() == null)) {
+        if(op.x() == null || (op.x() != null && op.x().data().platformAddress() == lastZ && op.z() == op.x() && op.y() == null)) {
            // we have possible shift here
            matchingCounter.incrementCount(prevOpMatching + " -> " + opClass);
            matchingCounterDetailed.incrementCount(prevOpMatchingDetailed + " -> " + opClass + " " + op.opName());

@ -254,7 +254,7 @@ public class OpProfiler {
            }

        }
-        lastZ = op.z() != null ? op.z().data().address() : 0L;
+        lastZ = op.z() != null ? op.z().data().platformAddress() : 0L;
        prevOpMatching = opClass;
        prevOpMatchingDetailed = opClass + " " + op.opName();
        prevOpMatchingInverted = opClass + " " + op.opName();

@ -610,6 +610,7 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory {
                        bb2.put((byte)((s >> 8) & 0xff));
                        bb2.put((byte)(s & 0xff));
                    }
+                    Nd4j.getAffinityManager().tagLocation(arr, AffinityManager.Location.HOST);
                    map.put(fName, arr.reshape(order, shape));
                } else if(dt == DataType.LONG){
                    long[] d = new long[(int)size];

@ -72,7 +72,7 @@ public class SynchronousFlowController implements FlowController {
    public void synchronizeToHost(AllocationPoint point) {

        if (!point.isActualOnHostSide()) {
-            CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
+            val context = (CudaContext) allocator.getDeviceContext().getContext();

            if (!point.isConstant())
                waitTillFinished(point);

@ -102,7 +102,7 @@ public class SynchronousFlowController implements FlowController {

        if (!point.isActualOnDeviceSide()) {
            if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
-                CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
+                val context = (CudaContext) allocator.getDeviceContext().getContext();

                long perfD = PerformanceTracker.getInstance().helperStartTransaction();


@ -135,17 +135,17 @@ public class SynchronousFlowController implements FlowController {

    @Override
    public CudaContext prepareActionAllWrite(INDArray... operands) {
-        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
-        int cId = allocator.getDeviceId();
+        val context = (CudaContext) allocator.getDeviceContext().getContext();
+        val cId = allocator.getDeviceId();

        for (INDArray operand : operands) {
-            if (operand == null)
+            if (operand == null || operand.isEmpty())
                continue;

            Nd4j.getCompressor().autoDecompress(operand);

-            AllocationPoint pointData = allocator.getAllocationPoint(operand);
-            AllocationPoint pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
+            val pointData = allocator.getAllocationPoint(operand);
+            val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());

            pointData.acquireLock();


@ -168,15 +168,15 @@ public class SynchronousFlowController implements FlowController {

    @Override
    public CudaContext prepareAction(INDArray result, INDArray... operands) {
-        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
-        int cId = allocator.getDeviceId();
+        val context = (CudaContext) allocator.getDeviceContext().getContext();
+        val cId = allocator.getDeviceId();


-        if (result != null) {
+        if (result != null && !result.isEmpty()) {
            Nd4j.getCompressor().autoDecompress(result);
            prepareDelayedMemory(result);
-            AllocationPoint pointData = allocator.getAllocationPoint(result);
-            AllocationPoint pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());
+            val pointData = allocator.getAllocationPoint(result);
+            val pointShape = allocator.getAllocationPoint(result.shapeInfoDataBuffer());

            pointData.acquireLock();


@ -196,13 +196,13 @@ public class SynchronousFlowController implements FlowController {
        }

        for (INDArray operand : operands) {
-            if (operand == null)
+            if (operand == null || operand.isEmpty())
                continue;

            Nd4j.getCompressor().autoDecompress(operand);

-            AllocationPoint pointData = allocator.getAllocationPoint(operand);
-            AllocationPoint pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());
+            val pointData = allocator.getAllocationPoint(operand);
+            val pointShape = allocator.getAllocationPoint(operand.shapeInfoDataBuffer());

            pointData.acquireLock();


@ -256,7 +256,7 @@ public class SynchronousFlowController implements FlowController {
            if (operand == null)
                continue;

-            AllocationPoint pointOperand = allocator.getAllocationPoint(operand);
+            val pointOperand = allocator.getAllocationPoint(operand);
            pointOperand.tickDeviceWrite();
            eventsProvider.storeEvent(pointOperand.getLastWriteEvent());
            pointOperand.setLastWriteEvent(eventsProvider.getEvent());

@ -266,9 +266,10 @@ public class SynchronousFlowController implements FlowController {
    }

    public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
-        if (result == null)
+        if (result == null || result.isEmpty())
            return;
-        AllocationPoint point = allocator.getAllocationPoint(result);
+
+        val point = allocator.getAllocationPoint(result);
        point.tickDeviceWrite();
        eventsProvider.storeEvent(point.getLastWriteEvent());
        point.setLastWriteEvent(eventsProvider.getEvent());

@ -276,10 +277,10 @@ public class SynchronousFlowController implements FlowController {
        point.releaseLock();

        for (INDArray operand : operands) {
-            if (operand == null)
+            if (operand == null || operand.isEmpty())
                continue;

-            AllocationPoint pointOperand = allocator.getAllocationPoint(operand);
+            val pointOperand = allocator.getAllocationPoint(operand);
            pointOperand.releaseLock();
            eventsProvider.storeEvent(pointOperand.getLastReadEvent());
            pointOperand.setLastReadEvent(eventsProvider.getEvent());

@ -289,7 +290,7 @@ public class SynchronousFlowController implements FlowController {

    @Override
    public CudaContext prepareAction(AllocationPoint result, AllocationPoint... operands) {
-        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
+        val context = (CudaContext) allocator.getDeviceContext().getContext();

        if (result != null) {
            result.acquireLock();

@ -299,6 +300,7 @@ public class SynchronousFlowController implements FlowController {
        for (AllocationPoint operand : operands) {
            if (operand == null)
                continue;
+
            operand.acquireLock();
            operand.setCurrentContext(context);
        }

@ -313,15 +315,16 @@ public class SynchronousFlowController implements FlowController {

    protected void prepareDelayedMemory(INDArray array) {
        if (configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
-            AllocationPoint pointData = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
-            AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
+            val pointData = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
+            val pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());

            if (pointData.getAllocationStatus() != AllocationStatus.DEVICE)
                prepareDelayedMemory(array.data());

            if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
-                DataBuffer oShape = array.shapeInfoDataBuffer();
-                DataBuffer nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);
+                val oShape = array.shapeInfoDataBuffer();
+                val nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);

                if (nShape == oShape)
                    Nd4j.getConstantHandler().moveToConstantSpace(nShape);
                ((JCublasNDArray) array).setShapeInfoDataBuffer(nShape);

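Reviewer note (not part of the diff): the guard added to prepareAction/registerAction above, reduced to a standalone sketch. Empty arrays carry no data buffer to decompress, lock or synchronize, so they are now skipped the same way null operands are. Names below are illustrative only.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class EmptyOperandGuardSketch {
        public static void main(String[] args) {
            INDArray[] operands = { Nd4j.create(DataType.FLOAT, 2, 2), Nd4j.empty(DataType.FLOAT), null };

            for (INDArray operand : operands) {
                if (operand == null || operand.isEmpty())
                    continue;   // nothing to prepare or register for this operand

                System.out.println("would prepare: " + operand.shapeInfoToString());
            }
        }
    }
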
@ -567,6 +567,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
        return allocationPoint.getPointers().getHostPointer().address();
    }

+    @Override
+    public long platformAddress() {
+        return allocationPoint.getPointers().getDevicePointer().address();
+    }
+
    @Override
    public Pointer pointer() {
        // FIXME: very bad thing,

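Reviewer note (not part of the diff): what the new platformAddress() is for, as a sketch. On the CUDA backend address() reports the host pointer while platformAddress() reports the device pointer, and OpProfiler now keys its op-chaining heuristic on platformAddress() so it also works for device-resident buffers. On the CPU backend the two are expected to be the same value (assumption).

    import org.nd4j.linalg.api.buffer.DataBuffer;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    public class PlatformAddressSketch {
        public static void main(String[] args) {
            DataBuffer buffer = Nd4j.createUninitialized(DataType.FLOAT, 4, 4).data();

            long hostAddress = buffer.address();             // host-side pointer
            long platformAddress = buffer.platformAddress(); // device pointer on CUDA; same as address() on CPU (assumption)

            System.out.println(hostAddress + " vs " + platformAddress);
        }
    }
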
@ -26,6 +26,7 @@ import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
+import org.nd4j.base.Preconditions;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;

@ -515,7 +516,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            }

                // in case of regular accumulation we don't care about array state before op
-                ret = Nd4j.createUninitialized(dtype, retShape);
+                ret = Nd4j.create(dtype, retShape);
            }
            op.setZ(ret);
        } else {

@ -536,11 +537,16 @@ public class CudaExecutioner extends DefaultOpExecutioner {
    @Override
    public INDArray exec(IndexAccumulation op) {
        val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector());
-        if (op.z() == null) {
-            long[] retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims());
-
-            INDArray ret = Nd4j.createUninitialized(DataType.LONG, retShape);
-            op.setZ(ret);
+        if (op.x().isEmpty()) {
+            for (val d : dimension) {
+                Preconditions.checkArgument(op.x().shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape");
+            }
+        }
+
+        if (op.z() == null) {
+            val retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims());
+            op.setZ(Nd4j.createUninitialized(DataType.LONG, retShape));
        }

        long st = profilingConfigurableHookIn(op);

@ -556,10 +562,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
            return op.x();
        }

+        if (op.z().isEmpty())
+            return op.z();
+
        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

-        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
+        val context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());

        val hostXShapeInfo =
                op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());

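Reviewer note (not part of the diff): the intent of the new checks in exec(IndexAccumulation), as a sketch. An index reduction (e.g. argmax) along an axis of extent 0 has no defined result, so it is now rejected with a Preconditions check, while an op whose z is empty simply returns z without launching a kernel. Creating the empty array via a zero-extent shape is an assumption for illustration.

    import org.nd4j.base.Preconditions;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class EmptyIndexReduceSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.create(DataType.FLOAT, 0, 3);   // empty: first dimension has extent 0
            int[] dimension = {0};

            try {
                for (int d : dimension) {
                    // mirrors the check added in CudaExecutioner above
                    Preconditions.checkArgument(x.shape()[d] != 0,
                            "IndexReduce can't be issued along axis with 0 in shape");
                }
            } catch (IllegalArgumentException e) {
                System.out.println("rejected as expected: " + e.getMessage());
            }
        }
    }
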
@ -619,9 +619,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
                            "Illegal concatenation at array " + i + " and shape element " + j);
                }
            }
-
-
-            //log.info("Shape[{}]: {}", i, Arrays.toString(toConcat[i].shapeInfoDataBuffer().asInt()));
        }

        if (allScalars) {

@ -630,8 +627,6 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
            outputShape[dimension] = sumAlongDim;
        }

-        //PointerPointer dummy = new PointerPointer(new Pointer[] {null});
-
        INDArray ret = Nd4j.createUninitialized(toConcat[0].dataType(), outputShape, Nd4j.order());

        nativeOps.concat(null, dimension, toConcat.length,

@ -639,11 +634,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory {
                null, null,
                ret.data().addressPointer(), (LongPointer) ret.shapeInfoDataBuffer().addressPointer(),
                null, null,
-                //new PointerPointer(new Pointer[] {null}), new PointerPointer(new Pointer[] {null}));
                null, null);

        return ret;
-        // return super.concat(dimension,toConcat);
    }


@ -31,6 +31,7 @@ import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;

@ -1083,7 +1084,7 @@ public class ReductionOpValidation extends BaseOpValidation {

        final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
                .divi(Math.sqrt(keys.size(1)));
-        Nd4j.exec(new SoftMax(exec, exec, 1));
+        Nd4j.exec((CustomOp) new SoftMax(exec, exec, 1));
        final INDArray finalOut = Nd4j.matmul(values, exec).norm1();

        SameDiff sd = SameDiff.create();

@ -1111,7 +1112,7 @@ public class ReductionOpValidation extends BaseOpValidation {
        final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
                .divi(Math.sqrt(keys.size(1)));
        exec.addi(mask.reshape(10, 3, 1).sub(1).muli(1e9));
-        Nd4j.exec(new SoftMax(exec, exec, 1));
+        Nd4j.exec((CustomOp) new SoftMax(exec, exec, 1));
        final INDArray finalOut = Nd4j.matmul(values, exec).norm1();

        SameDiff sd = SameDiff.create();

@ -1141,7 +1142,7 @@ public class ReductionOpValidation extends BaseOpValidation {
        final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
                .divi(Math.sqrt(keys.size(-2)));
        exec.addi(Nd4j.tile(mask.reshape(2, 1, 3, 1), 1, 5, 1, 2).sub(1).muli(1e9));
-        Nd4j.exec(new SoftMax(exec, exec, -2));
+        Nd4j.exec((CustomOp) new SoftMax(exec, exec, -2));
        final INDArray finalOut = Nd4j.matmul(values, exec).norm1();

        SameDiff sd = SameDiff.create();

@ -1169,7 +1170,7 @@ public class ReductionOpValidation extends BaseOpValidation {

        final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
                .divi(Math.sqrt(keys.size(-2)));
-        Nd4j.exec(new SoftMax(exec, exec, -2));
+        Nd4j.exec((CustomOp) new SoftMax(exec, exec, -2));
        final INDArray finalOut = Nd4j.matmul(values, exec).norm1();

        SameDiff sd = SameDiff.create();

@ -1249,7 +1250,7 @@ public class ReductionOpValidation extends BaseOpValidation {
        final INDArray exec = Nd4j.matmul(keys, query, true, false, false)
                .divi(Math.sqrt(keys.size(1)));
        exec.addi(mask.reshape(10, 3, 1).sub(1).muli(1e9));
-        Nd4j.exec(new SoftMax(exec, exec, 1));
+        Nd4j.exec((CustomOp) new SoftMax(exec, exec, 1));
        final INDArray finalOut = Nd4j.matmul(values, exec).norm1();

        for (char queryOrder : new char[]{'f', 'c'}) {

@ -42,6 +42,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldMin;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;

@ -671,7 +672,7 @@ public class TransformOpValidation extends BaseOpValidation {
                //TODO SHOULDN'T THIS HAVE A DIMENSION ARG???
                t = sd.nn().softmax(in);
                ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
-                tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new OldSoftMax(ia.dup())));
+                tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]);
                break;
            case 24:
                t = sd.math().sqrt(in);

@ -25,7 +25,7 @@ import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;

@ -60,7 +60,7 @@ public class LoneTest extends BaseNd4jTest {
        System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
        INDArray output = Nd4j.create(DataType.DOUBLE, 10, 1);
        System.out.println("Element wise stride of output " + output.elementWiseStride());
-        Nd4j.getExecutioner().exec(new OldSoftMax(input, output));
+        Nd4j.getExecutioner().exec(new SoftMax(input, output));
    }

    @Override

@ -41,9 +41,9 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
+import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
-import org.nd4j.linalg.api.ops.custom.Flatten;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.broadcast.*;

@ -72,13 +72,13 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.shape.Shape;

@ -94,8 +94,6 @@ import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
-import org.nd4j.linalg.profiler.OpProfiler;
-import org.nd4j.linalg.profiler.ProfilerConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.MathUtils;

@ -2919,7 +2917,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
        System.out.println("Input transpose " + Shape.shapeToString(input.shapeInfo()));
        INDArray output = Nd4j.create(10, 1);
        System.out.println("Element wise stride of output " + output.elementWiseStride());
-        Nd4j.getExecutioner().exec(new OldSoftMax(input, output));
+        Nd4j.getExecutioner().exec(new SoftMax(input, output));
    }

    @Test

@ -3134,7 +3132,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
    public void testSoftmaxRow() {
        for (int i = 0; i < 20; i++) {
            INDArray arr1 = Nd4j.zeros(1, 100);
-            Nd4j.getExecutioner().execAndReturn(new OldSoftMax(arr1));
+            Nd4j.getExecutioner().execAndReturn(new SoftMax(arr1));
            System.out.println(Arrays.toString(arr1.data().asFloat()));
        }
    }

@ -3779,7 +3777,7 @@ public class Nd4jTestsC extends BaseNd4jTest {

            for (int i = 0; i < 3; i++) {
                INDArray subset = result12.tensorAlongDimension(i, 1, 2);//result12.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all());
-                assertEquals("Failed for subset " + i, bc12, subset);
+                assertEquals("Failed for subset [" + i + "] orders [" + orderArr + "/" + orderbc + "]", bc12, subset);
            }
        }
    }

@ -5725,9 +5723,9 @@ public class Nd4jTestsC extends BaseNd4jTest {
        val reference = original.dup(original.ordering());
        val expected = original.dup(original.ordering());

-        Nd4j.getExecutioner().execAndReturn(new OldSoftMax(expected));
+        Nd4j.getExecutioner().execAndReturn((CustomOp) new SoftMax(expected, expected, -1));

-        val result = Nd4j.getExecutioner().exec(new OldSoftMax(original, original.dup(original.ordering())));
+        val result = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(original, original.dup(original.ordering())))[0];

        assertEquals(reference, original);
        assertEquals(expected, result);

@ -23,12 +23,12 @@ import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldLogSoftMax;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

@ -158,9 +158,9 @@ public class CrashTest extends BaseNd4jTest {


        // logisoftmax, softmax & softmax derivative
-        Nd4j.getExecutioner().exec(new OldSoftMax(x));
-        Nd4j.getExecutioner().exec(new SoftMaxDerivative(x));
-        Nd4j.getExecutioner().exec(new OldLogSoftMax(x));
+        Nd4j.getExecutioner().exec((CustomOp) new SoftMax(x));
+        Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(x));
+        Nd4j.getExecutioner().exec((CustomOp) new LogSoftMax(x));


        // BooleanIndexing

@ -30,12 +30,13 @@ import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.bool.IsInf;
import org.nd4j.linalg.api.ops.impl.reduce.bool.IsNaN;
import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.OldEqualTo;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

@ -334,7 +335,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
    public void testTypesValidation_3() {
        val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);

-        val result = Nd4j.getExecutioner().exec(new OldSoftMax(arrayX));
+        val result = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(arrayX, arrayX, -1));
    }

    public void testTypesValidation_4() {

@ -25,7 +25,9 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;

@ -217,8 +219,8 @@ public class DerivativeTests extends BaseNd4jTest {
            }
        }

-        INDArray sm = Nd4j.getExecutioner().exec(new OldSoftMax(z.dup()));
-        INDArray zPrime = Nd4j.getExecutioner().exec(new SoftMaxDerivative(z));
+        INDArray sm = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(z.dup()))[0];
+        INDArray zPrime = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(z))[0];
        System.out.println(Arrays.toString(sm.data().asDouble()));
        System.out.println(Arrays.toString(zPrime.data().asDouble()));
        assertNotEquals(sm, zPrime);

@ -396,7 +398,7 @@ public class DerivativeTests extends BaseNd4jTest {
        //random array represeting preout
        INDArray X = Nd4j.rand(1, 2);
        //preout transformed to y_hat with softmax
-        INDArray YHat = Nd4j.getExecutioner().exec(new OldSoftMax(X.dup()));
+        INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];

        //hard coding something to construct a function with, using MSE
        INDArray Y = Nd4j.create(new double[][] {{0.123, 1 - 0.123}});

@ -404,7 +406,7 @@ public class DerivativeTests extends BaseNd4jTest {
        //This is the MSE now
        double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();

-        INDArray softmaxDer = Nd4j.getExecutioner().exec(new SoftMaxDerivative(X.dup()));
+        INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];

        //the way we apply the chain rule now is 2*(y-yhat)*softmaxder
        INDArray dLdY = Y.sub(YHat).mul(-2);

@ -444,13 +446,13 @@ public class DerivativeTests extends BaseNd4jTest {
            double x = X.getDouble(0, i);
            Xiplus = X.dup();
            Xiplus.put(0, i, x + epsilon);
-            YHatplus = Nd4j.getExecutioner().exec(new OldSoftMax(Xiplus.dup()));
+            YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0];
            lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();

            // -epsilon
            Ximinus = X.dup();
            Ximinus.put(0, i, x - epsilon);
-            YHatminus = Nd4j.getExecutioner().exec(new OldSoftMax(Ximinus.dup()));
+            YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0];
            lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();

            double gradienti = (lossplus - lossminus) / (2 * epsilon);

@ -479,16 +481,16 @@ public class DerivativeTests extends BaseNd4jTest {

        INDArray X = Nd4j.rand(1, someLength);
        //preout transformed to y_hat with softmax
-        INDArray YHat = Nd4j.getExecutioner().exec(new OldSoftMax(X.dup()));
+        INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];

        //hard coding something to construct a function with, using MSE
        INDArray temp = Nd4j.rand(1, someLength);
-        INDArray Y = Nd4j.getExecutioner().exec(new OldSoftMax(temp));
+        INDArray Y = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(temp))[0];

        //This is the MSE now
        double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();

-        INDArray softmaxDer = Nd4j.getExecutioner().exec(new SoftMaxDerivative(X.dup()));
+        INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];

        //the way we apply the chain rule now is 2*(y-yhat)*softmaxder
        INDArray dLdY = Y.sub(YHat).mul(-2);

@ -511,13 +513,13 @@ public class DerivativeTests extends BaseNd4jTest {
            double x = X.getDouble(0, i);
            Xiplus = X.dup();
            Xiplus.put(0, i, x + epsilon);
-            YHatplus = Nd4j.getExecutioner().exec(new OldSoftMax(Xiplus.dup()));
+            YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0];
            lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();

            // -epsilon
            Ximinus = X.dup();
            Ximinus.put(0, i, x - epsilon);
-            YHatminus = Nd4j.getExecutioner().exec(new OldSoftMax(Ximinus.dup()));
+            YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0];
            lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();

            double gradienti = (lossplus - lossminus) / (2 * epsilon);

@ -538,14 +540,14 @@ public class DerivativeTests extends BaseNd4jTest {
        // this is only for X a row vector
        // should return rank 2 matrix diagonal elements are pi*(1-pi)
        //rest are -pi*pj
-        INDArray p = Nd4j.getExecutioner().exec(new OldSoftMax(X.dup()));
+        INDArray p = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];
        INDArray pCol = p.dup().transpose();
        INDArray pipj = pCol.mmul(p);
        pipj.muli(-1);

        //so now pipj is correct except for the diagonal elements
        // which by the way is what our current softmax der gives us
-        INDArray diagp = Nd4j.getExecutioner().exec(new SoftMaxDerivative(X.dup()));
+        INDArray diagp = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];


        //ugly for loop to correct diag elements

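Reviewer note (not part of the diff): the numerical check these tests keep using, reduced to its core. The analytic softmax-derivative gradient is compared against the central difference (L(x + eps) - L(x - eps)) / (2 * eps); the scalar loss below is a stand-in purely for illustration.

    import java.util.function.DoubleUnaryOperator;

    public class CentralDifferenceSketch {
        public static void main(String[] args) {
            DoubleUnaryOperator loss = v -> v * v;   // stand-in loss, assumption for illustration
            double x = 0.3;
            double eps = 1e-5;

            double numeric = (loss.applyAsDouble(x + eps) - loss.applyAsDouble(x - eps)) / (2 * eps);
            double analytic = 2 * x;                 // d(v^2)/dv

            System.out.println(Math.abs(numeric - analytic) < 1e-6);   // true
        }
    }
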
@ -24,6 +24,7 @@ import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
||||||
|
@ -42,11 +43,11 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
|
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
|
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
|
||||||
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
|
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.DropOut;
|
import org.nd4j.linalg.api.ops.random.impl.DropOut;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
|
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
|
||||||
|
@ -304,11 +305,11 @@ public class OpExecutionerTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRowSoftmax() {
|
public void testRowSoftmax() {
|
||||||
OpExecutioner opExecutioner = Nd4j.getExecutioner();
|
val opExecutioner = Nd4j.getExecutioner();
|
||||||
INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
val arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
|
||||||
OldSoftMax softMax = new OldSoftMax(arr);
|
val softMax = new SoftMax(arr);
|
||||||
opExecutioner.exec(softMax);
|
opExecutioner.exec((CustomOp) softMax);
|
||||||
assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1);
|
assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -373,7 +374,7 @@ public class OpExecutionerTests extends BaseNd4jTest {
      public void testSoftmax() {
          INDArray vec = Nd4j.linspace(1, 6, 6, DataType.DOUBLE);
          INDArray matrix = vec.dup().reshape('f', 2, 3);
-         Nd4j.getExecutioner().exec(new OldSoftMax(matrix));
+         Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix));
          INDArray matrixAssertion = Nd4j.create(
                  new double[] {0.015876241, 0.015876241, 0.11731043, 0.11731043, 0.86681336, 0.86681336},
                  new int[] {2, 3}, 'f');
@@ -384,7 +385,7 @@ public class OpExecutionerTests extends BaseNd4jTest {
      public void testOtherSoftmax() {
          INDArray vec = Nd4j.linspace(1, 18, 18, DataType.DOUBLE);
          INDArray matrix = vec.dup().reshape('f', 3, 6);
-         Nd4j.getExecutioner().exec(new OldSoftMax(matrix));
+         Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix));
          INDArray assertion = Nd4j.create(new double[] {2.9067235E-7, 2.9067235E-7, 2.9067235E-7, 5.8383102E-6,
                  5.8383102E-6, 5.8383102E-6, 1.1726559E-4, 1.1726559E-4, 1.1726559E-4, 0.0023553425,
                  0.0023553425, 0.0023553425, 0.047308315, 0.047308315, 0.047308315, 0.95021296, 0.95021296,
@@ -517,9 +518,9 @@ public class OpExecutionerTests extends BaseNd4jTest {
                  0.3049033, 0.29277474, 0.29136384, 0.30316526, 0.2807459}, new int[] {150, 3}, 'f');

          System.out.println("Data:" + input.data().length());
-         OldSoftMax softMax = new OldSoftMax(input);
-         Nd4j.getExecutioner().exec(softMax);
-         assertEquals(assertion, softMax.z());
+         val softMax = new SoftMax(input);
+         Nd4j.getExecutioner().exec((CustomOp) softMax);
+         assertEquals(assertion, softMax.outputArguments()[0]);

      }

@@ -557,9 +558,9 @@ public class OpExecutionerTests extends BaseNd4jTest {
      public void testSoftMax() {
          OpExecutioner opExecutioner = Nd4j.getExecutioner();
          INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
-         OldSoftMax softMax = new OldSoftMax(arr);
-         opExecutioner.exec(softMax);
-         assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1);
+         val softMax = new SoftMax(arr);
+         opExecutioner.exec((CustomOp) softMax);
+         assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);
      }

      @Test
@@ -29,6 +29,7 @@ import org.nd4j.linalg.api.buffer.DataType;
  import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
  import org.nd4j.linalg.api.iter.NdIndexIterator;
  import org.nd4j.linalg.api.ndarray.INDArray;
+ import org.nd4j.linalg.api.ops.CustomOp;
  import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
  import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
  import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
@@ -50,12 +51,12 @@ import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
  import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
  import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
  import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
+ import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
  import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram;
  import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
  import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
  import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
  import org.nd4j.linalg.api.ops.impl.transforms.strict.Log;
- import org.nd4j.linalg.api.ops.impl.transforms.strict.OldSoftMax;
  import org.nd4j.linalg.api.ops.impl.transforms.strict.SetRange;
  import org.nd4j.linalg.api.shape.Shape;
  import org.nd4j.linalg.factory.Nd4j;
@@ -96,9 +97,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
      public void testSoftmaxReference() {
          INDArray input = Nd4j.linspace(1,4,4, DataType.FLOAT).reshape(2,2);
          INDArray dup = input.dup();
-         Nd4j.getExecutioner().exec(new OldSoftMax(dup));
+         Nd4j.getExecutioner().exec((CustomOp) new SoftMax(dup));
          INDArray result = Nd4j.zeros(DataType.FLOAT, 2,2);
-         Nd4j.getExecutioner().exec(new OldSoftMax(input,result));
+         Nd4j.getExecutioner().exec((CustomOp) new SoftMax(input,result));
          assertEquals(dup,result);

@@ -322,9 +323,9 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
      public void testRowSoftmax() {
          OpExecutioner opExecutioner = Nd4j.getExecutioner();
          INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
-         OldSoftMax softMax = new OldSoftMax(arr);
-         opExecutioner.exec(softMax);
-         assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1);
+         val softMax = new SoftMax(arr);
+         opExecutioner.exec((CustomOp) softMax);
+         assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);
      }

      @Test
@@ -422,23 +423,23 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
      public void testSoftMax() {
          OpExecutioner opExecutioner = Nd4j.getExecutioner();
          INDArray arr = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(1, -1);
-         OldSoftMax softMax = new OldSoftMax(arr);
-         opExecutioner.exec(softMax);
-         assertEquals(getFailureMessage(), 1.0, softMax.z().sumNumber().doubleValue(), 1e-1);
+         val softMax = new SoftMax(arr);
+         opExecutioner.exec((CustomOp) softMax);
+         assertEquals(getFailureMessage(), 1.0, softMax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);

          INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
-         OldSoftMax softmax = new OldSoftMax(linspace.dup());
-         Nd4j.getExecutioner().exec(softmax);
-         assertEquals(linspace.rows(), softmax.z().sumNumber().doubleValue(), 1e-1);
+         val softmax = new SoftMax(linspace.dup());
+         Nd4j.getExecutioner().exec((CustomOp) softmax);
+         assertEquals(linspace.rows(), softmax.outputArguments()[0].sumNumber().doubleValue(), 1e-1);
      }


      @Test
      public void testDimensionSoftMax() {
          INDArray linspace = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
-         OldSoftMax max = new OldSoftMax(linspace);
-         Nd4j.getExecutioner().exec(max);
-         linspace.assign(max.z());
+         val max = new SoftMax(linspace);
+         Nd4j.getExecutioner().exec((CustomOp) max);
+         linspace.assign(max.outputArguments()[0]);
          assertEquals(getFailureMessage(), linspace.getRow(0).sumNumber().doubleValue(), 1.0, 1e-1);
      }

@@ -782,7 +783,7 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
      public void testSoftmax() {
          INDArray vec = Nd4j.linspace(1, 18, 18, DataType.DOUBLE);
          INDArray matrix = vec.dup().reshape(3, 6);
-         Nd4j.getExecutioner().exec(new OldSoftMax(matrix));
+         Nd4j.getExecutioner().exec((CustomOp) new SoftMax(matrix));
          INDArray assertion = Nd4j.create(
                  new double[] {0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913,
                          0.0042697787, 0.011606461, 0.031549633, 0.085760795, 0.23312202, 0.6336913,
@@ -381,30 +381,32 @@ public class OperationProfilerTests extends BaseNd4jTest {
          INDArray x = Nd4j.create(1000, 1000).assign(1.0);
          INDArray y = Nd4j.create(1000, 1000).assign(1.0);

-         for (int e = 0; e < 10000; e++) {
+         int iterations = 100;
+
+         for (int e = 0; e < iterations; e++) {
              x.addi(y);
          }

          Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC);

          val nanosC = System.nanoTime();
-         for (int e = 0; e < 10000; e++) {
+         for (int e = 0; e < iterations; e++) {
              x.addi(y);
          }
          val nanosD = System.nanoTime();

-         val avgB = (nanosD - nanosC) / 10000;
+         val avgB = (nanosD - nanosC) / iterations;


          Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);

          val nanosA = System.nanoTime();
-         for (int e = 0; e < 10000; e++) {
+         for (int e = 0; e < iterations; e++) {
              x.addi(y);
          }
          val nanosB = System.nanoTime();

-         val avgA = (nanosB - nanosA) / 10000;
+         val avgA = (nanosB - nanosA) / iterations;


          log.info("A: {}; B: {}", avgA, avgB);

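Note: this hunk only lowers the iteration count (10000 to 100) so the profiler-overhead check runs faster; the measurement pattern itself is unchanged. A minimal, hypothetical sketch of that pattern with the repeated timing loop factored into a helper (class and method names are illustrative, not part of the patch):

    import lombok.val;

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
    import org.nd4j.linalg.factory.Nd4j;

    // Hypothetical demo: warm up, then compare average ns/op with profiling on vs. off.
    public class ProfilerOverheadDemo {

        static long averageNanosPerAdd(INDArray x, INDArray y, int iterations) {
            val start = System.nanoTime();
            for (int e = 0; e < iterations; e++) {
                x.addi(y);      // in-place add, same op the test times
            }
            val end = System.nanoTime();
            return (end - start) / iterations;
        }

        public static void main(String[] args) {
            val x = Nd4j.create(1000, 1000).assign(1.0);
            val y = Nd4j.create(1000, 1000).assign(1.0);
            int iterations = 100;

            averageNanosPerAdd(x, y, iterations);   // warm-up pass, result discarded

            Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC);
            val avgProfiled = averageNanosPerAdd(x, y, iterations);

            Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED);
            val avgPlain = averageNanosPerAdd(x, y, iterations);

            System.out.println("profiled: " + avgProfiled + " ns/op, plain: " + avgPlain + " ns/op");
        }
    }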
@@ -1429,15 +1429,17 @@ public class RandomTests extends BaseNd4jTest {

      @Test
      public void testRngRepeatabilityUniform(){
+         val nexp = Nd4j.create(DataType.FLOAT, 10);
          Nd4j.getRandom().setSeed(12345);
-         INDArray out1 = Nd4j.create(DataType.FLOAT, 10);
+         val out1 = Nd4j.create(DataType.FLOAT, 10);
          Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out1, 0.0, 1.0));

          Nd4j.getRandom().setSeed(12345);
-         INDArray out2 = Nd4j.create(DataType.FLOAT, 10);
+         val out2 = Nd4j.create(DataType.FLOAT, 10);
          Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out2, 0.0, 1.0));

          assertEquals(out1, out2);
+         assertNotEquals(nexp, out1);
      }

      @Test
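Note: the extra assertNotEquals guards against the failure mode where the random op never writes its output, in which case out1 and out2 would trivially be equal (both all-zero). A minimal, hypothetical standalone sketch of the same check outside JUnit follows; the class name is illustrative and the DistributionUniform import path is an assumption, since the hunk does not show it:

    import lombok.val;

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;   // assumed package
    import org.nd4j.linalg.factory.Nd4j;

    // Hypothetical demo: seeding the default RNG makes the uniform fill reproducible,
    // and comparing against a fresh all-zero array confirms the op actually wrote its output.
    public class RngRepeatabilityDemo {
        public static void main(String[] args) {
            val zeros = Nd4j.create(DataType.FLOAT, 10);   // stays all-zero for the sanity check

            Nd4j.getRandom().setSeed(12345);
            val out1 = Nd4j.create(DataType.FLOAT, 10);
            Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out1, 0.0, 1.0));

            Nd4j.getRandom().setSeed(12345);
            val out2 = Nd4j.create(DataType.FLOAT, 10);
            Nd4j.exec(new DistributionUniform(Nd4j.createFromArray(10L), out2, 0.0, 1.0));

            System.out.println(out1.equals(out2));    // true: same seed, same sequence
            System.out.println(out1.equals(zeros));   // false: the array really was filled
        }
    }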
@@ -16,6 +16,7 @@

  package org.nd4j.linalg.shape;

+ import lombok.val;
  import org.junit.Test;
  import org.junit.runner.RunWith;
  import org.junit.runners.Parameterized;
@@ -48,18 +49,17 @@ public class ShapeBufferTests extends BaseNd4jTest {

      @Test
      public void testRank() {
-         int[] shape = {2, 4};
-         int[] stride = {1, 2};
-         IntBuffer buff = Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt();
-         int rank = 2;
-         assertEquals(rank, Shape.rank(buff));
-
+         long[] shape = {2, 4};
+         long[] stride = {1, 2};
+         val shapeInfoBuffer = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false);
+         val buff = shapeInfoBuffer.asNioLong();
+         assertEquals(2, Shape.rank(buff));
      }


      @Test
      public void testArrCreationShape() {
-         INDArray arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
+         val arr = Nd4j.linspace(1, 4, 4, DataType.DOUBLE).reshape(2, 2);
          for (int i = 0; i < 2; i++)
              assertEquals(2, arr.size(i));
          int[] stride = ArrayUtil.calcStrides(new int[] {2, 2});
@@ -70,12 +70,13 @@ public class ShapeBufferTests extends BaseNd4jTest {

      @Test
      public void testShape() {
-         int[] shape = {2, 4};
-         int[] stride = {1, 2};
-         IntBuffer buff = Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt();
-         IntBuffer shapeView = Shape.shapeOf(buff);
+         long[] shape = {2, 4};
+         long[] stride = {1, 2};
+         val shapeInfoBuffer = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false);
+         val buff = shapeInfoBuffer.asNioLong();
+         val shapeView = Shape.shapeOf(buff);
          assertTrue(Shape.contentEquals(shape, shapeView));
-         IntBuffer strideView = Shape.stride(buff);
+         val strideView = Shape.stride(buff);
          assertTrue(Shape.contentEquals(stride, strideView));
          assertEquals('c', Shape.order(buff));
          assertEquals(1, Shape.elementWiseStride(buff));
@@ -86,9 +87,9 @@ public class ShapeBufferTests extends BaseNd4jTest {

      @Test
      public void testBuff() {
-         int[] shape = {1, 2};
-         int[] stride = {1, 2};
-         IntBuffer buff = Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt();
+         long[] shape = {1, 2};
+         long[] stride = {1, 2};
+         val buff = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false).asNioLong();
          assertTrue(Shape.isVector(buff));
      }

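Note: the ShapeBufferTests hunks all follow the switch to long-based shape descriptors: Shape.createShapeInformation now takes long[] shape and stride plus an element-wise stride, order, a DataType and an "empty" flag, and the resulting shape-info buffer is read through asNioLong(). A minimal, hypothetical standalone sketch of building and querying such a buffer (class and method names illustrative, not part of the patch):

    import lombok.val;

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.shape.Shape;

    // Hypothetical demo: create a long-based shape-information buffer with the new
    // signature and read back rank, order, element-wise stride, shape and stride.
    public class ShapeInfoDemo {
        public static void main(String[] args) {
            long[] shape = {2, 4};
            long[] stride = {1, 2};

            // elementWiseStride = 1, order = 'c', data type DOUBLE, not an empty array
            val shapeInfoBuffer = Shape.createShapeInformation(shape, stride, 1, 'c', DataType.DOUBLE, false);
            val buff = shapeInfoBuffer.asNioLong();

            System.out.println(Shape.rank(buff));                                        // 2
            System.out.println(Shape.order(buff));                                       // 'c'
            System.out.println(Shape.elementWiseStride(buff));                           // 1
            System.out.println(Shape.contentEquals(shape, Shape.shapeOf(buff)));         // true
            System.out.println(Shape.contentEquals(stride, Shape.stride(buff)));         // true
        }
    }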
@@ -2661,4 +2661,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
          this.indexer = null;
          this.pointer = null;
      }

+     @Override
+     public long platformAddress() {
+         return address();
+     }
  }
@@ -84,6 +84,14 @@ public interface DataBuffer extends Serializable, AutoCloseable {
       */
      long address();

+     /**
+      * Returns the address of the platform-specific pointer:
+      * - for the native backend that will be the host pointer
+      * - for the cuda backend that will be the device pointer
+      * @return
+      */
+     long platformAddress();
+
      /**
       * Returns true if the underlying data source
       * is the same for both buffers (referential equals)
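Note: platformAddress() complements the existing address() accessor: BaseDataBuffer simply delegates to address(), so on the native backend both values coincide, while the CUDA backend is expected to return the device-side pointer. A minimal, hypothetical usage sketch (class name illustrative, not part of the patch):

    import org.nd4j.linalg.api.buffer.DataBuffer;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    // Hypothetical demo: read both addresses off an array's backing buffer.
    public class PlatformAddressDemo {
        public static void main(String[] args) {
            INDArray arr = Nd4j.create(3, 3);
            DataBuffer buffer = arr.data();

            long hostAddress = buffer.address();              // host pointer on all backends
            long platformAddress = buffer.platformAddress();  // host pointer (native) or device pointer (cuda)

            System.out.println("address: " + hostAddress + ", platformAddress: " + platformAddress);
        }
    }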