[WIP] More CUDA fixes/updates (#62)

* CUDA reallocation update

Signed-off-by: raver119 <raver119@gmail.com>

* Legacy SoftMax/LogSoftMax/SoftMaxDerivative removed from cpp

Signed-off-by: raver119 <raver119@gmail.com>

* SoftMaxDerivative op removed

Signed-off-by: raver119 <raver119@gmail.com>

* A few test updates

Signed-off-by: raver119 <raver119@gmail.com>

* RNG fixes

Signed-off-by: raver119 <raver119@gmail.com>

* A few more test updates

Signed-off-by: raver119 <raver119@gmail.com>

* legacy Histogram/Pooling2D removed

Signed-off-by: raver119 <raver119@gmail.com>

* legacy Histogram removed

Signed-off-by: raver119 <raver119@gmail.com>

* histogram moved

Signed-off-by: raver119 <raver119@gmail.com>

* Histogram moved (CUDA)

Signed-off-by: raver119 <raver119@gmail.com>

* Histogram custom op

Signed-off-by: raver119 <raver119@gmail.com>
Branch: master
raver119 2019-07-16 18:48:40 +03:00 committed by AlexDBlack
parent 5cf6859fc4
commit fd6c0df024
34 changed files with 509 additions and 832 deletions

View File

@ -827,6 +827,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES);
auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
rng->rewindH(shape::length(hZShapeInfo));
}
////////////////////////////////////////////////////////////////////////
@ -842,6 +845,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES);
auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
rng->rewindH(shape::length(hZShapeInfo));
}
////////////////////////////////////////////////////////////////////////
@ -859,6 +865,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);
BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES);
auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
rng->rewindH(shape::length(hZShapeInfo));
}

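The three host execRandom overloads above now call rewindH(shape::length(hZShapeInfo)) after the transform, so the host-side RandomGenerator advances past the values just consumed and the next launch draws fresh numbers. A minimal standalone sketch of that bookkeeping; ToyRng is illustrative, not nd4j::graph::RandomGenerator:

// Minimal sketch: the generator derives value i from (seed, offset + i), so the
// host copy must advance its offset by however many elements the launch produced.
#include <cstdint>
#include <cassert>

struct ToyRng {
    uint64_t seed, offset;
    uint64_t valueAt(uint64_t i) const { return seed * 6364136223846793005ULL + offset + i; }
    void rewindH(uint64_t consumed) { offset += consumed; }   // what rng->rewindH(length) does
};

int main() {
    ToyRng host{42, 0};
    const uint64_t len = 1000;               // shape::length(hZShapeInfo) in the hunks above
    uint64_t first = host.valueAt(0);
    host.rewindH(len);                       // added after every execTransform call
    assert(host.valueAt(0) != first);        // the next launch starts from fresh state
    return 0;
}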
View File

@ -774,78 +774,7 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
if (xType != zType || !DataTypeUtils::isR(xType))
throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType);
switch (opNum) {
case transform::SoftMax:
case transform::SoftMaxDerivative:
case transform::LogSoftMax: {
if (shape::isVector(hXShapeInfo)) {
int length = shape::length(hXShapeInfo);
int block = nd4j::math::nd4j_min<int>(length, 256);
launchDims.x = 1;
launchDims.y = block;
launchDims.z += (block * sizeof(double) * 4);
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, lc->getAllocationPointer(), lc->getReductionPointer(), nullptr, nullptr), FLOAT_TYPES);
} else {
auto shape = shape::shapeOf(hXShapeInfo);
auto reductionPointer = lc->getReductionPointer();
auto allocationPointer = lc->getAllocationPointer();
auto specialPointer = reinterpret_cast<double *>(allocationPointer);
// special pointer for special buffer for special ops
auto dimension = reinterpret_cast<int *>(specialPointer);
auto maxDimension = dimension + 1;
auto maxShapeBuffer = reinterpret_cast<Nd4jLong *>(maxDimension + 1);
auto special = reinterpret_cast<double *> (maxShapeBuffer + (MAX_RANK * 2 + 4));
Nd4jLong maxShape[2] = {shape::shapeOf(hXShapeInfo)[0], 1};
auto hostMaxShapeBuffer = shape::shapeBuffer(2, xType, maxShape);
prepareShapeBuffer<<<1, 1, 128, *stream>>>(dimension, maxDimension, maxShapeBuffer, shape[0], xType);
DEBUG_KERNEL(stream, opNum);
// max 3
execReduceSame(lc, reduce::Max, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, maxDimension, 1, tadShapeInfo, tadOffsets);
DEBUG_KERNEL(stream, opNum);
// sub 1
execBroadcast(lc, broadcast::Subtract, hX, hXShapeInfo, dX, dXShapeInfo, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, nullptr, hZShapeInfo, dZ, dZShapeInfo, dimension, 1, tadShapeInfo, tadOffsets, nullptr, nullptr);
DEBUG_KERNEL(stream, opNum);
// exp 3
execTransformFloat(lc, transform::Exp, hZ, hZShapeInfo, dZ, dZShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);
DEBUG_KERNEL(stream, opNum);
//sum 1
execReduceSame(lc, reduce::Sum, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, maxDimension, 1, tadShapeInfo, tadOffsets);
// divide 3
execBroadcast(lc, broadcast::Divide, hZ, hZShapeInfo, dZ, dZShapeInfo, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, nullptr, hZShapeInfo, dZ, dZShapeInfo, dimension, 1, tadShapeInfo, tadOffsets, nullptr, nullptr);
DEBUG_KERNEL(stream, opNum);
// log 3
if (opNum == transform::LogSoftMax)
execTransformFloat(lc, transform::Log, nullptr, hZShapeInfo, dZ, dZShapeInfo, nullptr, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);
else if (opNum == transform::SoftMaxDerivative)
execTransformStrict(lc, transform::SpecialDerivative, nullptr, hZShapeInfo, dZ, dZShapeInfo, nullptr, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);
nd4j::DebugHelper::checkErrorCode(stream, "SoftMax(...) failed");
delete hostMaxShapeBuffer;
}
}
break;
default: {
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES);
}
}
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES);
}
////////////////////////////////////////////////////////////////////////
@ -869,22 +798,8 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
if (!DataTypeUtils::isR(zType))
throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType);
if (opNum == transform::Histogram) {
dim3 launchDims(256, 256, 32768);
Nd4jPointer maskedallocationPointer;
auto length = shape::length(hZShapeInfo);
cudaMalloc(reinterpret_cast<void **>(&maskedallocationPointer), length * launchDims.x * DataTypeUtils::sizeOf(nd4j::DataType::INT64));
auto imaskedallocationPointer = reinterpret_cast<int *>(maskedallocationPointer);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, imaskedallocationPointer, reductionPointer, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
checkCudaErrors(cudaStreamSynchronize(*stream));
cudaFree(maskedallocationPointer);
} else {
dim3 launchDims(512, 512, 16384);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
}
dim3 launchDims(512, 512, 16384);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
}
@ -1197,12 +1112,15 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
dim3 launchDims = dim3(512, 512, 32768);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(stateHost);
// functions::random::RandomFunction<float>::executeCudaSingle(launchDims, extraPointers, opNum, stateHost, dZ, dZShapeInfo, extraArguments),
BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::executeCudaSingle(launchDims, stream, opNum, stateDevice, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES);
checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream));
checkCudaErrors(cudaStreamSynchronize(*stream));
cudaFree(stateDevice);
rng->rewindH(shape::length(hZShapeInfo));
}
////////////////////////////////////////////////////////////////////////
@ -1224,14 +1142,17 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
checkCudaErrors(cudaStreamSynchronize(*stream));
checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream));
auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(stateHost);
dim3 launchDims = dim3(512, 512, 32768);
auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);
// functions::random::RandomFunction<float>::executeCudaDouble(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments);
BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaDouble(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES);
checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream));
checkCudaErrors(cudaStreamSynchronize(*stream));
cudaFree(stateDevice);
rng->rewindH(shape::length(hZShapeInfo));
}
////////////////////////////////////////////////////////////////////////
@ -1254,14 +1175,17 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
checkCudaErrors(cudaStreamSynchronize(*stream));
checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream));
auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(stateHost);
dim3 launchDims = dim3(512, 512, 32768);
auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);
// functions::random::RandomFunction<float>::executeCudaTriple(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments);
BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaTriple(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES);
checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream));
checkCudaErrors(cudaStreamSynchronize(*stream));
cudaFree(stateDevice);
rng->rewindH(shape::length(hZShapeInfo));
}
////////////////////////////////////////////////////////////////////////

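On the CUDA path the pattern above is: copy the host generator state to the device, launch, copy the state back, synchronize, free the device copy, then rewind the host generator once. A standalone CUDA sketch of that round trip, with an illustrative ToyState standing in for RandomGenerator and a dummy kernel body:

#include <cuda_runtime.h>
#include <cstdio>

struct ToyState { unsigned long long seed, offset; };

__global__ void consume(const ToyState *state, float *z, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        z[i] = (float)((state->seed + state->offset + i) % 1000) / 1000.0f;
}

int main() {
    const int n = 1024;
    ToyState host = {42ULL, 0ULL};
    ToyState *device; float *z;
    cudaMalloc(&device, sizeof(ToyState));
    cudaMalloc(&z, n * sizeof(float));

    cudaStream_t stream; cudaStreamCreate(&stream);
    cudaMemcpyAsync(device, &host, sizeof(ToyState), cudaMemcpyHostToDevice, stream);
    consume<<<(n + 255) / 256, 256, 0, stream>>>(device, z, n);
    cudaMemcpyAsync(&host, device, sizeof(ToyState), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);        // mirrors checkCudaErrors(cudaStreamSynchronize(*stream))
    host.offset += n;                     // mirrors rng->rewindH(shape::length(hZShapeInfo))
    printf("host offset is now %llu\n", host.offset);

    cudaFree(device); cudaFree(z); cudaStreamDestroy(stream);
    return 0;
}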
View File

@ -1958,7 +1958,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers,
void *extraArguments) {
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
}
////////////////////////////////////////////////////////////////////////
@ -1970,7 +1970,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer st
void *extraArguments) {
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
}
////////////////////////////////////////////////////////////////////////
@ -1984,7 +1984,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer st
void *extraArguments) {
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
}

View File

@ -152,9 +152,9 @@ namespace functions {
__syncthreads();
// using this loop instead of memcpy
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) {
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x)
cB[e] = dB[e];
}
__syncthreads();
@ -213,9 +213,9 @@ namespace functions {
__syncthreads();
// using this loop instead of memcpy
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) {
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x)
cB[e] = dB[e];
}
__syncthreads();
@ -262,9 +262,9 @@ namespace functions {
__syncthreads();
// using this loop instead of memcpy
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) {
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x)
cB[e] = dB[e];
}
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;

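The loops above copy the RandomGenerator bytes into shared memory with a strided per-thread loop instead of memcpy; only the braces were dropped, the behaviour is unchanged. A compilable standalone sketch of that copy-then-barrier idiom (names are illustrative):

#include <cuda_runtime.h>

__global__ void shareState(const unsigned char *deviceState, int stateSize) {
    extern __shared__ unsigned char localState[];             // cB: block-local copy
    for (int e = threadIdx.x; e < stateSize; e += blockDim.x)
        localState[e] = deviceState[e];                        // dB: state in global memory
    __syncthreads();                                           // copy visible to the whole block
    // ... every thread in the block may now read localState ...
}

int main() {
    const int stateSize = 64;          // stand-in for sizeof(nd4j::graph::RandomGenerator)
    unsigned char *state;
    cudaMalloc(&state, stateSize);
    cudaMemset(state, 0, stateSize);
    shareState<<<1, 128, stateSize>>>(state, stateSize);
    cudaDeviceSynchronize();
    cudaFree(state);
    return 0;
}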
View File

@ -107,9 +107,6 @@
#define TRANSFORM_STRICT_OPS \
(0, SoftMax), \
(1, SoftMaxDerivative), \
(2, LogSoftMax) ,\
(3, ELUDerivative), \
(4, TanhDerivative), \
(5, HardTanhDerivative), \
@ -167,9 +164,7 @@
// these ops return one of FLOAT data types
#define TRANSFORM_FLOAT_OPS \
(0, Histogram), \
(1, Sqrt), \
(2, Pooling2D) ,\
(3, RSqrt)

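Because every entry in these op lists carries an explicit numeric id, deleting SoftMax/SoftMaxDerivative/LogSoftMax and Histogram/Pooling2D does not renumber the ops that remain. A generic sketch of the (id, Name) X-macro idea; this is not the real BUILD_*_SELECTOR machinery, just the principle:

#include <cstdio>

#define MY_FLOAT_OPS \
    X(1, Sqrt)       \
    X(3, RSqrt)

enum FloatOps {
#define X(id, name) name = id,
    MY_FLOAT_OPS
#undef X
};

int main() {
    printf("Sqrt=%d RSqrt=%d\n", (int) Sqrt, (int) RSqrt);   // still 1 and 3
    return 0;
}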
View File

@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_histogram)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>
#include <ops/declarable/helpers/histogram.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(histogram, 1, 1, false, 0, 1) {
auto input = INPUT_VARIABLE(0);
auto numBins = INT_ARG(0);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length")
helpers::histogramHelper(block.launchContext(), *input, *output);
return Status::OK();
}
DECLARE_SHAPE_FN(histogram) {
auto numBins = INT_ARG(0);
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(numBins, nd4j::DataType::INT64));
}
DECLARE_TYPES(histogram) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
->setAllowedOutputTypes({ALL_INTS});
};
}
}
#endif

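Functionally, the new custom op takes numBins from the first integer argument, emits an INT64 vector of that length, and clamps out-of-range bin indices to the first or last bin. A standalone C++ sketch of that contract, without any libnd4j types:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// numBins bins spanning [min, max] of the input, INT64 counts, clamped indices.
std::vector<int64_t> histogram(const std::vector<float> &input, int64_t numBins) {
    auto mm = std::minmax_element(input.begin(), input.end());
    double minVal = *mm.first, maxVal = *mm.second;
    double binSize = (maxVal - minVal) / (double) numBins;
    std::vector<int64_t> bins(numBins, 0);
    for (float v : input) {
        auto idx = (int64_t)((v - minVal) / binSize);
        idx = std::max<int64_t>(0, std::min<int64_t>(idx, numBins - 1));  // clamp like the helpers
        bins[idx]++;
    }
    return bins;
}

int main() {
    std::vector<float> x{0.1f, 0.2f, 0.9f, 1.0f, 0.5f};
    for (auto c : histogram(x, 4))
        printf("%lld ", (long long) c);
    printf("\n");
    return 0;
}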
View File

@ -206,6 +206,13 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_hashcode)
DECLARE_REDUCTION_OP(hashcode, 1, 1, false, 0, 0);
#endif
/**
* This operation calculates number of entries per bin
*/
#if NOT_EXCLUDED(OP_histogram)
DECLARE_CUSTOM_OP(histogram, 1, 1, false, 0, 1);
#endif
}
}

View File

@ -0,0 +1,83 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/histogram.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename X, typename Z>
static void histogram_(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) {
auto dx = reinterpret_cast<X*>(xBuffer);
auto result = reinterpret_cast<Z*>(zBuffer);
int length = shape::length(xShapeInfo);
// FIXME: 2???
int _threads = 2;
int span = (length / _threads) + 8;
X binSize = (max_val - min_val) / (numBins);
PRAGMA_OMP_PARALLEL_THREADS(_threads)
{
int tid, start, end;
int *bins = new int[numBins];
std::memset(bins, 0, sizeof(int) * numBins);
tid = omp_get_thread_num();
start = span * tid;
end = span * (tid + 1);
if (end > length) end = length;
PRAGMA_OMP_SIMD
for (int x = start; x < end; x++) {
int idx = (int) ((dx[x] - min_val) / binSize);
if (idx < 0)
idx = 0;
else if (idx >= numBins)
idx = numBins - 1;
bins[idx]++;
}
PRAGMA_OMP_CRITICAL
{
PRAGMA_OMP_SIMD
for (int x = 0; x < numBins; x++) {
result[x] += bins[x];
}
}
delete[] bins;
}
}
void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) {
Nd4jLong numBins = output.lengthOf();
double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0);
double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0);
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES);
}
}
}
}

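The CPU helper above gives each thread a private bin array and merges the partials under a critical section, so the hot loop runs without contention. A sketch of the same pattern with plain OpenMP pragmas in place of the PRAGMA_OMP_* wrappers (illustrative only):

#include <cstdio>
#include <vector>

int main() {
    std::vector<float> data(1000, 0.5f);
    const int numBins = 10;
    const float minVal = 0.0f, maxVal = 1.0f;
    const float binSize = (maxVal - minVal) / numBins;
    std::vector<long long> result(numBins, 0);

    #pragma omp parallel
    {
        std::vector<long long> local(numBins, 0);           // private bins: no contention in the hot loop
        #pragma omp for
        for (int i = 0; i < (int) data.size(); i++) {
            int idx = (int) ((data[i] - minVal) / binSize);
            if (idx < 0) idx = 0;
            else if (idx >= numBins) idx = numBins - 1;
            local[idx]++;
        }
        #pragma omp critical                                // only the merge is serialized
        for (int b = 0; b < numBins; b++)
            result[b] += local[b];
    }

    for (auto c : result)
        printf("%lld ", c);
    printf("\n");
    return 0;
}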
View File

@ -0,0 +1,134 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/histogram.h>
#include <NDArrayFactory.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename X, typename Z>
void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, double min_val, double max_val) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
auto dx = reinterpret_cast<X*>(xBuffer);
auto result = reinterpret_cast<Z*>(zBuffer);
__shared__ Z *bins;
__shared__ int length;
__shared__ Z *reductor;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
bins = (Z *) shmem;
reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x);
length = shape::length(xShapeInfo);
}
__syncthreads();
Z binSize = (max_val - min_val) / (numBins);
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
bins[e] = (Z) 0.0f;
}
__syncthreads();
for (int e = tid; e < length; e+= blockDim.x * gridDim.x) {
int idx = (int) ((dx[e] - min_val) / binSize);
if (idx < 0) idx = 0;
else if (idx >= numBins) idx = numBins - 1;
nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f);
}
__syncthreads();
// transfer shared memory to reduction memory
if (gridDim.x > 1) {
unsigned int *tc = (unsigned int *)reductionPointer;
__shared__ bool amLast;
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
reductor[e] = bins[e];
}
__threadfence();
__syncthreads();
if (threadIdx.x == 0) {
unsigned int ticket = atomicInc(&tc[16384], gridDim.x);
amLast = (ticket == gridDim.x - 1);
}
__syncthreads();
if (amLast) {
tc[16384] = 0;
// nullify shared memory for future accumulation
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
bins[e] = (Z) 0.0f;
}
// accumulate reduced bins
for (int r = 0; r < gridDim.x; r++) {
Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins);
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
bins[e] += ptrBuf[e];
}
}
__syncthreads();
// write them out to Z
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
result[e] = bins[e];
}
}
} else {
// if there's only 1 block - just write away data
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
result[e] = bins[e];
}
}
}
template <typename X, typename Z>
static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) {
int numThreads = 256;
int numBlocks = nd4j::math::nd4j_max<int>(256, nd4j::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads));
int workspaceSize = numBlocks * numBins;
auto tmp = NDArrayFactory::create<Z>('c',{workspaceSize});
histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, xShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, min_val, max_val);
cudaStreamSynchronize(*context->getCudaStream());
}
void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) {
Nd4jLong numBins = output.lengthOf();
double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0);
double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0);
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES);
NDArray::registerSpecialUse({&output}, {&input});
}
}
}
}

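histogramKernel finishes with a grid-wide reduction: every block publishes its partial bins, takes a ticket via atomicInc on the reduction buffer (tc[16384] in the diff), and the last block to arrive folds all partials into the output. A standalone CUDA sketch of that "last block reduces" pattern, with illustrative names and sizes:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void lastBlockSum(int *partials, int *out, int numBins, unsigned int *ticket) {
    __shared__ bool amLast;

    // step 1: each block writes its partial bins (here simply 1 per bin)
    for (int e = threadIdx.x; e < numBins; e += blockDim.x)
        partials[blockIdx.x * numBins + e] = 1;
    __threadfence();                                     // publish this thread's writes grid-wide
    __syncthreads();                                     // whole block done before taking a ticket

    if (threadIdx.x == 0) {
        unsigned int t = atomicInc(ticket, gridDim.x);   // returns the old value, wraps at gridDim.x
        amLast = (t == gridDim.x - 1);                   // true only in the last arriving block
    }
    __syncthreads();

    // step 2: the last block folds all partials into the final output
    if (amLast) {
        for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
            int sum = 0;
            for (int r = 0; r < gridDim.x; r++)
                sum += partials[r * numBins + e];
            out[e] = sum;                                // equals gridDim.x per bin here
        }
        if (threadIdx.x == 0)
            *ticket = 0;                                 // reset, like tc[16384] = 0 above
    }
}

int main() {
    const int numBins = 16, blocks = 8, threads = 64;
    int *partials, *out;
    unsigned int *ticket;
    cudaMalloc(&partials, blocks * numBins * sizeof(int));
    cudaMalloc(&out, numBins * sizeof(int));
    cudaMalloc(&ticket, sizeof(unsigned int));
    cudaMemset(ticket, 0, sizeof(unsigned int));
    lastBlockSum<<<blocks, threads>>>(partials, out, numBins, ticket);
    cudaDeviceSynchronize();
    int host[numBins];
    cudaMemcpy(host, out, sizeof(host), cudaMemcpyDeviceToHost);
    printf("bin 0 = %d (expected %d)\n", host[0], blocks);
    cudaFree(partials); cudaFree(out); cudaFree(ticket);
    return 0;
}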
View File

@ -0,0 +1,34 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_HISTOGRAM_H
#define LIBND4J_HISTOGRAM_H
#include <NDArray.h>
namespace nd4j {
namespace ops {
namespace helpers {
void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output);
}
}
}
#endif //DEV_TESTS_HISTOGRAM_H

View File

@ -747,88 +747,7 @@ static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outSha
int *allocationPointer, Z *reductionPointer,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int numBins = (int) extraParams[0];
Z min_val = extraParams[1];
Z max_val = extraParams[2];
int tid = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ Z *bins;
__shared__ int length;
__shared__ Z *reductor;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
bins = (Z *) shmem;
reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x);
length = shape::length(xShapeBuffer);
}
__syncthreads();
Z binSize = (max_val - min_val) / (numBins);
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
bins[e] = (Z) 0.0f;
}
__syncthreads();
for (int e = tid; e < length; e+= blockDim.x * gridDim.x) {
int idx = (int) ((dx[e] - min_val) / binSize);
if (idx < 0) idx = 0;
else if (idx >= numBins) idx = numBins - 1;
nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f);
}
__syncthreads();
// transfer shared memory to reduction memory
if (gridDim.x > 1) {
unsigned int *tc = (unsigned int *)reductionPointer;
__shared__ bool amLast;
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
reductor[e] = bins[e];
}
__threadfence();
__syncthreads();
if (threadIdx.x == 0) {
unsigned int ticket = atomicInc(&tc[16384], gridDim.x);
amLast = (ticket == gridDim.x - 1);
}
__syncthreads();
if (amLast) {
tc[16384] = 0;
// nullify shared memory for future accumulation
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
bins[e] = (Z) 0.0f;
}
// accumulate reduced bins
for (int r = 0; r < gridDim.x; r++) {
Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins);
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
bins[e] += ptrBuf[e];
}
}
__syncthreads();
// write them out to Z
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
result[e] = bins[e];
}
}
} else {
// if there's only 1 block - just write away data
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
result[e] = bins[e];
}
}
};
#endif
@ -840,52 +759,7 @@ static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outSha
Nd4jLong *zShapeBuffer,
Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
int length = shape::length(xShapeBuffer);
int _threads = 2;
int numBins = (int) extraParams[0];
int span = (length / _threads) + 8;
// get min over input
T min_val = extraParams[1];
T max_val = extraParams[2];
T binSize = (max_val - min_val) / (numBins);
PRAGMA_OMP_PARALLEL_THREADS(_threads)
{
int tid, start, end;
int *bins = new int[numBins];
std::memset(bins, 0, sizeof(int) * numBins);
tid = omp_get_thread_num();
start = span * tid;
end = span * (tid + 1);
if (end > length) end = length;
PRAGMA_OMP_SIMD
for (int x = start; x < end; x++) {
int idx = (int) ((dx[x] - min_val) / binSize);
if (idx < 0)
idx = 0;
else if (idx >= numBins)
idx = numBins - 1;
bins[idx]++;
}
PRAGMA_OMP_CRITICAL
{
PRAGMA_OMP_SIMD
for (int x = 0; x < numBins; x++) {
result[x] += bins[x];
}
}
delete[] bins;
}
}

View File

@ -88,10 +88,10 @@ namespace randomOps {
__syncthreads();
// using this loop instead of memcpy
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) {
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x)
cB[e] = dB[e];
}
// __syncthreads(); // Eliminated due RTX20xx specific
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -137,10 +137,6 @@ namespace randomOps {
// __syncthreads(); // Eliminated due RTX20xx specific
}
}
// __syncthreads(); // Eliminated due RTX20xx specific
if (threadIdx.x == 0 && blockIdx.x == 0)
devRng->rewindH(zLength);
}
#endif
@ -208,9 +204,6 @@ namespace randomOps {
}
}
}
// update rng state
rng->rewindH(zLength);
}
};
@ -277,7 +270,7 @@ namespace randomOps {
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x)
cB[e] = dB[e];
// __syncthreads(); // Eliminated due RTX20xx specific
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -299,10 +292,6 @@ namespace randomOps {
z[epm * zEWS] = (nd4j::math::nd4j_sqrt<T,T>(t * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean1;
}
}
// __syncthreads(); // Eliminated due RTX20xx specific
if (threadIdx.x == 0 && blockIdx.x == 0)
devRng->rewindH(zLength);
}
#endif
@ -352,9 +341,6 @@ namespace randomOps {
z[epm * zEWS] = z1;
}
}
// update rng state
rng->rewindH(zLength);
}
};
@ -401,10 +387,10 @@ namespace randomOps {
__syncthreads();
// using this loop instead of memcpy
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) {
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x)
cB[e] = dB[e];
}
// __syncthreads(); // Eliminated due RTX20xx specific
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -423,10 +409,6 @@ namespace randomOps {
// if trials is set to 0, effectively we just have successful memset
z[e * zEWS] = static_cast<T>(success);
}
// __syncthreads(); // Eliminated due RTX20xx specific
if (trials > 0 && threadIdx.x == 0 && blockIdx.x == 0)
devRng->rewindH(zLength * trials);
}
#endif
@ -472,10 +454,6 @@ namespace randomOps {
z[e * zEWS] = static_cast<T>(success);
}
}
// update rng state
if (trials > 0)
rng->rewindH(zLength * trials);
}
};
@ -522,10 +500,10 @@ namespace randomOps {
__syncthreads();
// using this loop instead of memcpy
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) {
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x)
cB[e] = dB[e];
}
// __syncthreads(); // Eliminated due RTX20xx specific
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -591,10 +569,6 @@ namespace randomOps {
z[e * zEWS] = static_cast<T>(success);
}
}
// update rng state
if (trials > 0)
rng->rewindH(zLength * trials);
}
};
@ -676,10 +650,10 @@ namespace randomOps {
__syncthreads();
// using this loop instead of memcpy
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) {
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x)
cB[e] = dB[e];
}
// __syncthreads(); // Eliminated due RTX20xx specific
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -724,9 +698,6 @@ namespace randomOps {
z[e] = mean + nd4j::DataTypeUtils::min<T>();
}
}
// update rng state
rng->rewindH(zLength);
}
};
@ -788,10 +759,10 @@ namespace randomOps {
__syncthreads();
// using this loop instead of memcpy
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) {
for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x)
cB[e] = dB[e];
}
// __syncthreads(); // Eliminated due RTX20xx specific
__syncthreads();
int tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -868,9 +839,6 @@ namespace randomOps {
}
}
}
// update rng state
rng->rewindH(zLength);
}
};

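The per-op rewindH calls removed above are now redundant: NativeOpExecutioner::execRandom (earlier in this commit) rewinds the host generator exactly once per launch, and keeping both would advance the stream twice per call. A toy illustration of the double-advance, not libnd4j code:

#include <cassert>

struct ToyRng {
    long long offset;
    void rewindH(long long n) { offset += n; }
};

int main() {
    const long long zLength = 512;
    ToyRng onceOnly = {0}, both = {0};
    onceOnly.rewindH(zLength);                       // executioner only (behaviour after this commit)
    both.rewindH(zLength); both.rewindH(zLength);    // op-level rewind kept as well
    assert(both.offset == 2 * onceOnly.offset);      // stream advanced twice -> values skipped
    return 0;
}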
View File

@ -1724,9 +1724,9 @@ namespace nd4j {
}
};
TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");
//TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");
output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax");
//output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax");
return output;
}

View File

@ -60,11 +60,11 @@ namespace nd4j {
sbRelu.setY(NDArrayFactory::create_<T>(0.0));
TransformBenchmark tbSigmoid(transform::StrictOps::Sigmoid, "sigmoid");
TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");
//TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");
output += helper.runOperationSuit(&sbRelu, generator, batch, "RELU");
output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Sigmoid");
output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax");
//output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax");
return output;
}

View File

@ -49,7 +49,6 @@ import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
@ -748,8 +747,6 @@ public class LegacyOpMapper {
public static Class<?> transformFloatingOpClass(int opNum){
switch (opNum){
case 0:
return Histogram.class;
case 1:
return Sqrt.class;
case 3:

View File

@ -70,10 +70,10 @@ import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
@ -888,7 +888,6 @@ public class OpValidation {
SELUDerivative.class,
SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative.class,
SoftSignDerivative.class,
TanhDerivative.class,
SwishDerivative.class,
@ -987,7 +986,6 @@ public class OpValidation {
OneHot.class,
BinaryMinimalRelativeError.class,
BinaryMinimalRelativeError.class,
Histogram.class,
InvertPermutation.class, //Uses integer indices
ConfusionMatrix.class, //Integer indices
Linspace.class, //No input array
@ -1056,7 +1054,9 @@ public class OpValidation {
LogicalAnd.class,
LogicalNot.class,
LogicalOr.class,
LogicalXor.class
LogicalXor.class,
Histogram.class
);
return new HashSet<>(list);

View File

@ -424,7 +424,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum.class,
org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast.class,
org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram.class,
org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt.class,
org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class,
@ -552,7 +551,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.Sin.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.Stabilize.class,

View File

@ -19,20 +19,6 @@ package org.nd4j.linalg.activations;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
/**
* This enum is the factory for the activation function.

View File

@ -0,0 +1,55 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
/**
* Histogram op wrapper
*
* @author raver119@gmail.com
*/
public class Histogram extends DynamicCustomOp {
private long numBins;
public Histogram() {
}
public Histogram(INDArray input, INDArray output) {
Preconditions.checkArgument(output.isZ(), "Histogram op output should have integer data type");
numBins = output.length();
inputArguments.add(input);
outputArguments.add(output);
iArguments.add(numBins);
}
public Histogram(INDArray input, long numBins) {
this.numBins = numBins;
inputArguments.add(input);
iArguments.add(numBins);
}
@Override
public String opName() {
return "histogram";
}
}

View File

@ -1,114 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.floating;
import lombok.val;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
* @author raver119@gmail.com
*/
public class Histogram extends BaseTransformFloatOp {
public Histogram(SameDiff sameDiff, SDVariable i_v, boolean inPlace, int numBins) {
super(sameDiff, i_v, inPlace);
this.numBins = numBins;
}
public Histogram(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs, int numBins) {
super(sameDiff, i_v, shape, inPlace, extraArgs);
this.numBins = numBins;
}
public Histogram(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, int numBins) {
super(sameDiff, i_v, extraArgs);
this.numBins = numBins;
}
private int numBins = 0;
public Histogram() {
//no-op
}
public Histogram(INDArray x, INDArray z) {
setX(x);
setZ(z);
//FIXME: int cast
numBins = (int) z.length();
double max = x.maxNumber().doubleValue();
double min = x.minNumber().doubleValue();
this.extraArgs = new Object[] {(double) numBins, min, max};
}
public Histogram(INDArray x, int numberOfBins) {
this(x, Nd4j.create(x.dataType(), numberOfBins));
}
@Override
public Map<String, Object> propertiesForFunction() {
Map<String,Object> ret = new LinkedHashMap<>();
ret.put("numBins",numBins);
return ret;
}
@Override
public int opNum() {
return 0;
}
@Override
public String opName() {
return "histogram";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -1,76 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.strict;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import java.nio.Buffer;
/**
* Softmax derivative
*
* @author Adam Gibson
*/
public class SoftMaxDerivative extends SoftMax {
public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, new SDVariable[]{i_v1, i_v2});
}
public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
super(sameDiff, new SDVariable[]{ i_v1, i_v2}, inPlace);
}
public SoftMaxDerivative(INDArray x, INDArray z) {
super(x, z);
}
public SoftMaxDerivative(INDArray x) {
super(x, x);
}
public SoftMaxDerivative() {}
@Override
public int opNum() {
return 1;
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public String opName() {
return "_softmaxderivative";
}
}

View File

@ -1529,85 +1529,40 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
Nd4j.getExecutioner().commit();
AllocationPoint old = allocationPoint;
allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false);
AllocationPoint old = allocationPoint;
allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false);
Nd4j.getDeallocatorService().pickObject(this);
trackingPoint = allocationPoint.getObjectId();
Nd4j.getDeallocatorService().pickObject(this);
trackingPoint = allocationPoint.getObjectId();
val oldLength = this.length;
this.length = length;
switch(dataType()){
case DOUBLE:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asDoublePointer();
indexer = DoubleIndexer.create((DoublePointer) pointer);
break;
case FLOAT:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asFloatPointer();
indexer = FloatIndexer.create((FloatPointer) pointer);
break;
case BFLOAT16:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer();
indexer = Bfloat16Indexer.create((ShortPointer) pointer);
break;
case HALF:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer();
indexer = ShortIndexer.create((ShortPointer) pointer);
break;
case LONG:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer();
indexer = LongIndexer.create((LongPointer) pointer);
break;
case UINT64:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer();
indexer = LongIndexer.create((LongPointer) pointer);
break;
case INT:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);
break;
case UINT32:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer();
indexer = IntIndexer.create((IntPointer) pointer);
break;
case SHORT:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer();
indexer = ShortIndexer.create((ShortPointer) pointer);
break;
case UINT16:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer();
indexer = UShortIndexer.create((ShortPointer) pointer);
break;
case BYTE:
this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer();
indexer = ByteIndexer.create((BytePointer) pointer);
break;
default:
throw new UnsupportedOperationException();
}
// if original buffer had host pointer allocated, we'll reallocate host buffer as well
if (old.getHostPointer() != null) {
lazyAllocateHostPointer();
}
CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream());
val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream());
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
if (old.isActualOnDeviceSide()) {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), this.length * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, context.getSpecialStream());
} else if (old.isActualOnHostSide()) {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), this.length * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
direction = MemcpyDirection.HOST_TO_DEVICE;
}
if (old.isActualOnDeviceSide()) {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), oldLength * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, context.getSpecialStream());
} else if (old.isActualOnHostSide()) {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), oldLength * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
direction = MemcpyDirection.HOST_TO_DEVICE;
}
context.getSpecialStream().synchronize();
context.getSpecialStream().synchronize();
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction);
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction);
allocationPoint.tickDeviceWrite();
// we're keeping pointer reference for JVM
pointer.address();
allocationPoint.tickDeviceWrite();
// we need to update length with new value now
//this.length = length;
// we need to update length with new value now
//this.length = length;
if(isAttached()){
// do nothing here, that's workspaces
} else{
@ -1619,7 +1574,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override
public long capacity() {
return pointer.capacity();
if (allocationPoint.getHostPointer() != null)
return pointer.capacity();
else
return length;
}
@Override

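At the CUDA level, the reallocation path above amounts to: allocate new device storage, zero it, copy the old contents (note the diff now copies oldLength elements rather than the new length), and treat the device copy as current. A plain CUDA sketch of that sequence, independent of the JavaCPP/AtomicAllocator call chain:

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

// Grow a device buffer while keeping its old contents.
float* reallocateDevice(float *oldBuf, size_t oldLength, size_t newLength, cudaStream_t stream) {
    float *newBuf = nullptr;
    cudaMalloc(&newBuf, newLength * sizeof(float));
    cudaMemsetAsync(newBuf, 0, newLength * sizeof(float), stream);       // like memsetAsync(...) above
    size_t toCopy = std::min(oldLength, newLength);                      // copy old length, not new
    cudaMemcpyAsync(newBuf, oldBuf, toCopy * sizeof(float), cudaMemcpyDeviceToDevice, stream);
    cudaStreamSynchronize(stream);                                       // like getSpecialStream().synchronize()
    cudaFree(oldBuf);
    return newBuf;
}

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    float *buf;
    cudaMalloc(&buf, 4 * sizeof(float));
    float init[4] = {1, 2, 3, 4};
    cudaMemcpy(buf, init, sizeof(init), cudaMemcpyHostToDevice);
    buf = reallocateDevice(buf, 4, 6, stream);                           // grow 4 -> 6, old data kept
    float host[6];
    cudaMemcpy(host, buf, sizeof(host), cudaMemcpyDeviceToHost);
    printf("%g %g %g %g %g %g\n", host[0], host[1], host[2], host[3], host[4], host[5]);
    cudaFree(buf);
    cudaStreamDestroy(stream);
    return 0;
}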
View File

@ -16754,6 +16754,27 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
private native void allocate();
}
// #endif
/**
* This operation calculates number of entries per bin
*/
// #if NOT_EXCLUDED(OP_histogram)
@Namespace("nd4j::ops") public static class histogram extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public histogram(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public histogram(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public histogram position(long position) {
return (histogram)super.position(position);
}
public histogram() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif

View File

@ -79,7 +79,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
@ -931,13 +930,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
System.out.println("Second: OK");
}
@Test
public void testSoftmaxDerivative() {
INDArray input = Nd4j.create(new double[] {-1.07, -0.01, 0.45, 0.95, 0.45, 0.16, 0.20, 0.80, 0.89, 0.25}).reshape(1, -1).transpose();
INDArray output = Nd4j.create(10, 1);
Nd4j.getExecutioner().exec(new SoftMaxDerivative(input, output));
}
@Test
public void testVStackDifferentOrders() {
@ -7245,19 +7237,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testInvalidTransformsSoftmax(){
INDArray arr = Nd4j.zeros(2,3,4);
try{
Transforms.softmax(arr);
fail("Expected exception");
} catch (IllegalArgumentException e){
//OK
assertTrue(e.getMessage().contains("rank 2"));
}
}
@Test
public void testEmptyCasting(){
for(val from : DataType.values()) {

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.api.buffer;
import lombok.val;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -28,6 +29,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.io.*;
import java.util.Arrays;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@ -85,10 +87,12 @@ public class IntDataBufferTests extends BaseNd4jTest {
public void testReallocation() {
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
assertEquals(4, buffer.capacity());
int[] old = buffer.asInt();
buffer.reallocate(6);
val old = buffer.asInt();
assertEquals(6, buffer.capacity());
assertArrayEquals(old, buffer.asInt());
val newContent = buffer.asInt();
assertEquals(6, newContent.length);
assertArrayEquals(old, Arrays.copyOf(newContent, old.length));
}
@Test
@ -98,12 +102,14 @@ public class IntDataBufferTests extends BaseNd4jTest {
MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
int[] old = buffer.asInt();
val old = buffer.asInt();
assertTrue(buffer.isAttached());
assertEquals(4, buffer.capacity());
buffer.reallocate(6);
assertEquals(6, buffer.capacity());
assertArrayEquals(old, buffer.asInt());
val newContent = buffer.asInt();
assertEquals(6, newContent.length);
assertArrayEquals(old, Arrays.copyOf(newContent, old.length));
workspace.close();
}

View File

@ -82,12 +82,12 @@ public class RngTests extends BaseNd4jTest {
INDArray narr = Nd4j.randn('c', rows, cols);
assertArrayEquals(new long[] {rows, cols}, narr.shape());
assertEquals('c', narr.ordering());
assertEquals(narr.meanNumber().doubleValue(), 0.0, 0.05);
assertEquals(0.0, narr.meanNumber().doubleValue(), 0.05);
INDArray narr2 = Nd4j.randn('f', rows, cols);
assertArrayEquals(new long[] {rows, cols}, narr2.shape());
assertEquals('f', narr2.ordering());
assertEquals(narr2.meanNumber().doubleValue(), 0.0, 0.05);
assertEquals(0.0, narr2.meanNumber().doubleValue(), 0.05);
INDArray narr3 = Nd4j.randn('c', new int[] {rows, cols, dim2});
assertArrayEquals(new long[] {rows, cols, dim2}, narr3.shape());
@ -97,7 +97,7 @@ public class RngTests extends BaseNd4jTest {
INDArray narr4 = Nd4j.randn('f', new int[] {rows, cols, dim2});
assertArrayEquals(new long[] {rows, cols, dim2}, narr4.shape());
assertEquals('f', narr4.ordering());
assertEquals(narr4.meanNumber().doubleValue(), 0.0, 0.05);
assertEquals(0.0, narr4.meanNumber().doubleValue(), 0.05);
}

View File

@ -29,7 +29,6 @@ import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.BooleanIndexing;
@ -159,7 +158,6 @@ public class CrashTest extends BaseNd4jTest {
// logisoftmax, softmax & softmax derivative
Nd4j.getExecutioner().exec((CustomOp) new SoftMax(x));
Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(x));
Nd4j.getExecutioner().exec((CustomOp) new LogSoftMax(x));

View File

@ -441,8 +441,8 @@ public class BooleanIndexingTest extends BaseNd4jTest {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC);
NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(true);
INDArray arr = Nd4j.linspace(1,4,4, Nd4j.dataType()).reshape(2,2);
INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{arr},Arrays.asList(2.0), Collections.emptyList(),new GreaterThan());
assertEquals(4,filtered.length());
INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{arr}, Arrays.asList(2.0), Collections.emptyList(),new GreaterThan());
assertEquals(2, filtered.length());
}
@ -450,7 +450,7 @@ public class BooleanIndexingTest extends BaseNd4jTest {
public void testChooseGreaterThanZero() {
INDArray zero = Nd4j.linspace(0,4,4, Nd4j.dataType());
INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{zero},Arrays.asList(0.0), Collections.emptyList(),new GreaterThan());
assertEquals(3,filtered.length());
assertEquals(3, filtered.length());
}
@Test

View File

@ -331,7 +331,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
assertArrayEquals(exp, arr);
}
@Test(expected = IllegalArgumentException.class)
@Test(expected = RuntimeException.class)
public void testTypesValidation_3() {
val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);

View File

@ -195,47 +195,6 @@ public class DerivativeTests extends BaseNd4jTest {
}
@Test
public void testSoftMaxDerivative() {
Random r = new Random(12345L);
int[] mb = new int[] {10, 2, 1};
for (int minibatch : mb) {
System.out.println("Minibatch size: " + minibatch);
INDArray z = Nd4j.zeros(minibatch, 5);
double[][] in = new double[minibatch][5];
double[][] softmax = new double[minibatch][5];
double[][] expOut = new double[minibatch][5];
for (int i = 0; i < minibatch; i++) {
double rowSumExp = 0.0;
for (int j = 0; j < 5; j++) {
in[i][j] = 10 * r.nextDouble();
z.putScalar(new int[] {i, j}, in[i][j]);
rowSumExp += FastMath.exp(in[i][j]);
}
for (int j = 0; j < 5; j++) {
softmax[i][j] = FastMath.exp(in[i][j]) / rowSumExp;
expOut[i][j] = softmax[i][j] * (1.0 - softmax[i][j]);
}
}
INDArray sm = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(z.dup()))[0];
INDArray zPrime = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(z))[0];
System.out.println(Arrays.toString(sm.data().asDouble()));
System.out.println(Arrays.toString(zPrime.data().asDouble()));
assertNotEquals(sm, zPrime);
for (int i = 0; i < minibatch; i++) {
for (int j = 0; j < 5; j++) {
double relError = Math.abs(expOut[i][j] - zPrime.getDouble(i, j))
/ (Math.abs(expOut[i][j]) + Math.abs(zPrime.getDouble(i, j)));
// System.out.println("Error: " + relError);
assertTrue(relError < REL_ERROR_TOLERANCE);
}
}
}
}
@Test
public void testSoftPlusDerivative() {
@ -384,180 +343,6 @@ public class DerivativeTests extends BaseNd4jTest {
}
}
@Test
public void softmaxsimpleLossTest() {
/*
Softmax derivative is correct if it is standalone,
but when we apply it in the chain rule the current derivative function is incomplete.
For this test, I am assuming that the function of interest is just MSE.
What the fix is:
the derivative of softmax needs to return a rank 2 matrix (the full Jacobian);
right now we only get the diagonal elements of this matrix (a worked form of the full Jacobian follows this test).
http://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network
*/
//random array representing preout
INDArray X = Nd4j.rand(1, 2);
//preout transformed to y_hat with softmax
INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];
//hard coding something to construct a function with, using MSE
INDArray Y = Nd4j.create(new double[][] {{0.123, 1 - 0.123}});
//This is the MSE now
double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();
INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];
//the way we apply the chain rule now is 2*(y-yhat)*softmaxder
INDArray dLdY = Y.sub(YHat).mul(-2);
INDArray currentGradient = dLdY.mul(softmaxDer);
//what I think we should be doing
// we have x0, x1 -> y0,y1
//we need the derivatives of the output of the softmax wrt every input (x0,x1)
// we only have dy0/dx0 and dy1/dx1
// we also need dy0/dx1 and dy1/dx0
// the below is the chain rule from calculus, applied when L is a function of y0 and y1, which are in turn functions of BOTH x0 and x1
// dL/dx0 = (dL/dy0) * (dy0/dx0) + (dL/dy1) * (dy1/dx0)
// dL/dx1 = (dL/dy0) * (dy0/dx1) + (dL/dy1) * (dy1/dx1)
// worked it out on paper and googled it (should have googled first, gave formula from link above)
// dy0/dx0 = y0*(1-y0) = y0*y1
// dy1/dx0 = -y1*(1-y1) = -y0*y1
// dy0/dx1 = -y0*(1-y0) = -y0*y1
// dy1/dx1 = y1*(1-y1) = y0*y1
// [dL/dy0, dL/dy1] * [[dy0/dx0, dy1/dx0], [dy0/dx1, dy1/dx1]]
double y0y1 = softmaxDer.getDouble(0, 0);
//hack but this is what we need to implement, straightforward here but complicated for >2
//INDArray mysoftmaxDer = Nd4j.create(new double[][] {{y0y1,y0y1*-1},{-1*y0y1,y0y1}});
INDArray mysoftmaxDer = correctSoftmax(X);
INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1);
double epsilon = 0.0001;
INDArray Xiplus, Ximinus;
INDArray YHatplus, YHatminus;
double lossplus, lossminus;
INDArray numGradient = Nd4j.zeros(1, 2);
for (int i = 0; i < 2; i++) {
/* change X one value one at a time */
// +epsilon
double x = X.getDouble(0, i);
Xiplus = X.dup();
Xiplus.put(0, i, x + epsilon);
YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0];
lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();
// -epsilon
Ximinus = X.dup();
Ximinus.put(0, i, x - epsilon);
YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0];
lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();
double gradienti = (lossplus - lossminus) / (2 * epsilon);
numGradient.put(0, i, gradienti);
}
System.out.println("=========================");
System.out.println("NUMERICAL:");
System.out.println(numGradient);
System.out.println("\nCURRENTLY:");
System.out.println(currentGradient);
System.out.println("\nMY GRADIENT:");
System.out.println(myGradient + "\n");
System.out.println(
"Because of the nature of the derivative of the softmax for length = 2, our current method will make it off by a factor of 2");
System.out.println("=========================");
}
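The factor-of-two claim printed above can be made precise with a short worked derivation (added here for clarity; it is not part of the original test). Writing p_0 and p_1 for the two softmax outputs, the full Jacobian and the chain rule are

J = \begin{pmatrix} p_0(1-p_0) & -p_0 p_1 \\ -p_0 p_1 & p_1(1-p_1) \end{pmatrix}
  = p_0 p_1 \begin{pmatrix} 1 & -1 \\ -1 & 1 \end{pmatrix},
\qquad
\frac{\partial L}{\partial x_i} = \sum_j \frac{\partial L}{\partial p_j}\,\frac{\partial p_j}{\partial x_i}.

For the MSE used here, with both the targets Y and the softmax outputs summing to 1, dL/dp1 = -dL/dp0, so the full chain rule yields twice the diagonal-only value in each component, which is exactly the factor of 2 the test reports.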
@Test
public void softmaxsimplelongerlengthLossTest() {
/*
Read comments in earlier test for length = 2
*/
//random array representing preout
int someLength = 7;
INDArray X = Nd4j.rand(1, someLength);
//preout transformed to y_hat with softmax
INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];
//hard coding something to construct a function with, using MSE
INDArray temp = Nd4j.rand(1, someLength);
INDArray Y = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(temp))[0];
//This is the MSE now
double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();
INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];
//the way we apply the chain rule now is 2*(y-yhat)*softmaxder
INDArray dLdY = Y.sub(YHat).mul(-2);
INDArray currentGradient = dLdY.mul(softmaxDer);
INDArray mysoftmaxDer = correctSoftmax(X);
INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1);
double epsilon = 0.0001;
INDArray Xiplus, Ximinus;
INDArray YHatplus, YHatminus;
double lossplus, lossminus;
INDArray numGradient = Nd4j.zeros(1, someLength);
for (int i = 0; i < someLength; i++) {
/* change X one value one at a time */
// +epsilon
double x = X.getDouble(0, i);
Xiplus = X.dup();
Xiplus.put(0, i, x + epsilon);
YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0];
lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();
// -epsilon
Ximinus = X.dup();
Ximinus.put(0, i, x - epsilon);
YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0];
lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();
double gradienti = (lossplus - lossminus) / (2 * epsilon);
numGradient.put(0, i, gradienti);
}
System.out.println("=========================");
System.out.println("NUMERICAL GRADIENT:");
System.out.println(new NDArrayStrings(6).format(numGradient).toString());
System.out.println("\nANALYTIC USING EXISTING SOFTMAX DER:");
System.out.println(new NDArrayStrings(6).format(currentGradient).toString());
System.out.println("\nGRADIENT USING MY VERSION OF SOFTMAX DER:");
System.out.println(new NDArrayStrings(6).format(myGradient).toString());
System.out.println("=========================");
}
public static INDArray correctSoftmax(INDArray X) {
// this only handles X as a row vector
// should return a rank 2 matrix: diagonal elements are pi*(1-pi),
// the rest are -pi*pj
INDArray p = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];
INDArray pCol = p.dup().transpose();
INDArray pipj = pCol.mmul(p);
pipj.muli(-1);
//so now pipj is correct except for the diagonal elements
// which by the way is what our current softmax der gives us
INDArray diagp = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];
//ugly for loop to correct diag elements
for (int i = 0; i < X.length(); i++) {
pipj.put(i, i, diagp.getDouble(0, i));
}
return pipj;
}
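A self-contained sketch of the same construction as correctSoftmax, written against plain Java arrays instead of the nd4j API so the full Jacobian J = diag(p) - p p^T and the chain-rule gradient can be checked in isolation (class and method names below are illustrative only):

// Stand-alone illustration of the full softmax Jacobian for a row vector; not part of nd4j.
public class SoftmaxJacobianSketch {

    // Numerically stable softmax of a 1-D input.
    static double[] softmax(double[] x) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : x) max = Math.max(max, v);
        double sum = 0.0;
        double[] p = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            p[i] = Math.exp(x[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < x.length; i++) p[i] /= sum;
        return p;
    }

    // Full Jacobian: J[i][k] = dp_i/dx_k = p_i * (delta_ik - p_k), i.e. diag(p) - p p^T.
    static double[][] jacobian(double[] p) {
        int n = p.length;
        double[][] j = new double[n][n];
        for (int i = 0; i < n; i++)
            for (int k = 0; k < n; k++)
                j[i][k] = p[i] * ((i == k ? 1.0 : 0.0) - p[k]);
        return j;
    }

    // Chain rule: dL/dx_k = sum_i dL/dp_i * dp_i/dx_k.
    static double[] backprop(double[][] j, double[] dLdP) {
        double[] dLdX = new double[dLdP.length];
        for (int k = 0; k < dLdX.length; k++)
            for (int i = 0; i < dLdP.length; i++)
                dLdX[k] += dLdP[i] * j[i][k];
        return dLdX;
    }

    public static void main(String[] args) {
        double[] p = softmax(new double[] {0.3, -1.2, 2.0});
        double[] dLdP = {0.1, -0.4, 0.3}; // arbitrary upstream gradient
        System.out.println(java.util.Arrays.toString(backprop(jacobian(p), dLdP)));
    }
}

The removed SoftMaxDerivative op corresponds to the diagonal j[i][i] = p_i * (1 - p_i) only, which is why the tests above build the full matrix by hand.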
@Override
public char ordering() {
return 'f';

View File

@ -50,9 +50,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
@ -876,14 +877,14 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
@Test
public void testHistogram1() {
INDArray x = Nd4j.linspace(1, 1000, 100000, DataType.DOUBLE);
INDArray z = Nd4j.zeros(DataType.DOUBLE,new long[]{20});
INDArray z = Nd4j.zeros(DataType.LONG,new long[]{20});
INDArray xDup = x.dup();
INDArray zDup = z.dup();
INDArray zExp = Nd4j.create(DataType.DOUBLE, 20).assign(5000);
INDArray zExp = Nd4j.create(DataType.LONG, 20).assign(5000);
Histogram histogram = new Histogram(x, z);
val histogram = new Histogram(x, z);
Nd4j.getExecutioner().exec(histogram);
@ -899,16 +900,13 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
public void testHistogram2() {
INDArray x = Nd4j.create(new float[] {0f, 0f, 0f, 5f, 5f, 5f, 10f, 10f, 10f});
INDArray xDup = x.dup();
INDArray zExp = Nd4j.zeros(DataType.FLOAT, 10).putScalar(0, 3f).putScalar(5, 3f).putScalar(9, 3f);
INDArray zExp = Nd4j.zeros(DataType.LONG, 10).putScalar(0, 3).putScalar(5, 3).putScalar(9, 3);
Histogram histogram = new Histogram(x, 10);
val histogram = new Histogram(x, 10);
Nd4j.getExecutioner().exec(histogram);
INDArray z = histogram.z();
val z = Nd4j.getExecutioner().exec(histogram)[0];
assertEquals(xDup, x);
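The expected counts in both tests follow from plain fixed-width binning over [min, max]: 100000 values spread evenly over [1, 1000] into 20 bins gives 5000 per bin, and {0, 0, 0, 5, 5, 5, 10, 10, 10} over [0, 10] into 10 bins puts three values each into bins 0, 5 and 9, with the maximum clamped into the last bin. A small stand-alone sketch of that binning rule follows; it assumes, rather than demonstrates, the exact rule the custom Histogram op uses, and the class name is illustrative only.

// Illustrative fixed-width histogram over [min, max]; the maximum value is clamped into the last bin.
public class HistogramSketch {

    static long[] histogram(double[] values, int numBins) {
        double min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY;
        for (double v : values) {
            min = Math.min(min, v);
            max = Math.max(max, v);
        }
        double binWidth = (max - min) / numBins;
        long[] counts = new long[numBins];
        for (double v : values) {
            int bin = (int) ((v - min) / binWidth);
            if (bin >= numBins) bin = numBins - 1; // clamp v == max into the last bin
            counts[bin]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        long[] counts = histogram(new double[] {0, 0, 0, 5, 5, 5, 10, 10, 10}, 10);
        System.out.println(java.util.Arrays.toString(counts));
        // -> [3, 0, 0, 0, 0, 3, 0, 0, 0, 3], matching zExp above
    }
}

Note that the updated test reads the result with Nd4j.getExecutioner().exec(histogram)[0] and expects LONG counts, which is the custom-op form of Histogram introduced by this commit.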


@ -19,6 +19,7 @@ package org.nd4j.linalg.serde;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.val;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
@ -58,7 +59,7 @@ public class JsonSerdeTests extends BaseNd4jTest {
Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4).muli(20).subi(10);
ObjectMapper om = new ObjectMapper();
val om = new ObjectMapper();
for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT, DataType.SHORT,
DataType.BYTE, DataType.UBYTE, DataType.BOOL, DataType.UTF8}) {


@ -673,7 +673,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
workspace.initializeWorkspace();
long reqMemory = 12 * Nd4j.sizeOfDataType(arrayCold.dataType());
assertEquals(reqMemory + reqMemory % 8 + Nd4j.sizeOfDataType(DOUBLE), workspace.getCurrentSize());
assertEquals(reqMemory + reqMemory % 8, workspace.getCurrentSize());
log.info("-----------------------");
@ -692,7 +692,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
array.addi(1.0);
assertEquals(reqMem + reqMem % 8 + Nd4j.sizeOfDataType(DOUBLE), workspace.getPrimaryOffset());
assertEquals(reqMem + reqMem % 8, workspace.getPrimaryOffset());
assertEquals("Failed on iteration " + x, 10, array.sumNumber().doubleValue(), 0.01);
@ -746,7 +746,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
INDArray dup = array.dup();
assertEquals((reqMemory + reqMemory % 8) * 2 + Nd4j.sizeOfDataType(DOUBLE), workspace.getPrimaryOffset());
assertEquals((reqMemory + reqMemory % 8) * 2, workspace.getPrimaryOffset());
assertEquals(5, dup.sumNumber().doubleValue(), 0.01);
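The updated expectations above are just the 8-byte-aligned buffer sizes with the extra sizeOfDataType term dropped. A tiny worked sketch of that arithmetic, assuming DOUBLE arrays (8 bytes per element), which is what these tests appear to size against; the class name is illustrative only:

// Worked example of the reqMemory + reqMemory % 8 expression used by the assertions above.
public class WorkspaceSizeSketch {
    public static void main(String[] args) {
        long elementSize = 8;                      // assumed DOUBLE, 8 bytes per element
        long reqMemory = 12 * elementSize;         // 96 bytes for the 12 elements the test sizes against
        long expected = reqMemory + reqMemory % 8; // 96 + 0 = 96, no trailing sizeOfDataType term
        System.out.println(expected);
    }
}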


@ -91,7 +91,7 @@ public class DebugModeTests extends BaseNd4jTest {
assertEquals(0, ws.getDeviceOffset());
// array buffer should be spilled now
assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE) + Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize());
assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize());
}
}
@ -118,7 +118,7 @@ public class DebugModeTests extends BaseNd4jTest {
assertEquals(0, ws.getDeviceOffset());
// array buffer should be spilled now
assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE) + Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize());
assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize());
}
try (val ws = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "R_119_1992")) {