[WIP] More CUDA fixes/updates (#62)
* CUDA reallocation update
* Legacy SoftMax/LogSoftMax/SoftMaxDerivative removed from cpp
* SoftMaxDerivative op removed
* few tests updates
* RNG fixes
* few more tests updates
* legacy Histogram/Pooling2D removed
* legacy Histogram removed
* histogram moved
* histogram moved cuda
* Histogram custom op

Signed-off-by: raver119 <raver119@gmail.com>
parent 5cf6859fc4
commit fd6c0df024
@@ -827,6 +827,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES);

    auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
    rng->rewindH(shape::length(hZShapeInfo));
}

////////////////////////////////////////////////////////////////////////

@@ -842,6 +845,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES);

    auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
    rng->rewindH(shape::length(hZShapeInfo));
}

////////////////////////////////////////////////////////////////////////

@@ -859,6 +865,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
    auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES);

    auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
    rng->rewindH(shape::length(hZShapeInfo));
}
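These three hunks appear to add the missing host-side rewind after each legacy random op: the generator consumes one state per generated element, so the host counter must advance by the output length or the next op would replay the same random stream. A minimal sketch of that invariant, with a toy counter-based generator standing in for nd4j::graph::RandomGenerator (all names here are illustrative, not the library's API):

#include <cstdint>

// Toy counter-based generator: value i of an op is a pure function of
// (seed, counter + i), so advancing the counter by the number of values
// produced guarantees the next op draws fresh randomness.
struct ToyRng {
    uint64_t seed;
    uint64_t counter = 0;
    uint64_t valueAt(uint64_t i) const { return (seed ^ (counter + i)) * 0x9E3779B97F4A7C15ULL; }
    void rewindH(uint64_t consumed) { counter += consumed; }   // mirrors rng->rewindH(length)
};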
@@ -774,78 +774,7 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
    if (xType != zType || !DataTypeUtils::isR(xType))
        throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType);

    switch (opNum) {
        case transform::SoftMax:
        case transform::SoftMaxDerivative:
        case transform::LogSoftMax: {
            if (shape::isVector(hXShapeInfo)) {
                int length = shape::length(hXShapeInfo);
                int block = nd4j::math::nd4j_min<int>(length, 256);

                launchDims.x = 1;
                launchDims.y = block;
                launchDims.z += (block * sizeof(double) * 4);

                BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, lc->getAllocationPointer(), lc->getReductionPointer(), nullptr, nullptr), FLOAT_TYPES);
            } else {
                auto shape = shape::shapeOf(hXShapeInfo);
                auto reductionPointer = lc->getReductionPointer();
                auto allocationPointer = lc->getAllocationPointer();
                auto specialPointer = reinterpret_cast<double *>(allocationPointer);

                // special pointer for special buffer for special ops
                auto dimension = reinterpret_cast<int *>(specialPointer);
                auto maxDimension = dimension + 1;
                auto maxShapeBuffer = reinterpret_cast<Nd4jLong *>(maxDimension + 1);
                auto special = reinterpret_cast<double *>(maxShapeBuffer + (MAX_RANK * 2 + 4));

                Nd4jLong maxShape[2] = {shape::shapeOf(hXShapeInfo)[0], 1};
                auto hostMaxShapeBuffer = shape::shapeBuffer(2, xType, maxShape);

                prepareShapeBuffer<<<1, 1, 128, *stream>>>(dimension, maxDimension, maxShapeBuffer, shape[0], xType);

                DEBUG_KERNEL(stream, opNum);

                // max 3
                execReduceSame(lc, reduce::Max, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, maxDimension, 1, tadShapeInfo, tadOffsets);

                DEBUG_KERNEL(stream, opNum);

                // sub 1
                execBroadcast(lc, broadcast::Subtract, hX, hXShapeInfo, dX, dXShapeInfo, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, nullptr, hZShapeInfo, dZ, dZShapeInfo, dimension, 1, tadShapeInfo, tadOffsets, nullptr, nullptr);

                DEBUG_KERNEL(stream, opNum);

                // exp 3
                execTransformFloat(lc, transform::Exp, hZ, hZShapeInfo, dZ, dZShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);

                DEBUG_KERNEL(stream, opNum);

                //sum 1
                execReduceSame(lc, reduce::Sum, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, maxDimension, 1, tadShapeInfo, tadOffsets);

                // divide 3
                execBroadcast(lc, broadcast::Divide, hZ, hZShapeInfo, dZ, dZShapeInfo, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, nullptr, hZShapeInfo, dZ, dZShapeInfo, dimension, 1, tadShapeInfo, tadOffsets, nullptr, nullptr);

                DEBUG_KERNEL(stream, opNum);

                // log 3
                if (opNum == transform::LogSoftMax)
                    execTransformFloat(lc, transform::Log, nullptr, hZShapeInfo, dZ, dZShapeInfo, nullptr, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);
                else if (opNum == transform::SoftMaxDerivative)
                    execTransformStrict(lc, transform::SpecialDerivative, nullptr, hZShapeInfo, dZ, dZShapeInfo, nullptr, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets);

                nd4j::DebugHelper::checkErrorCode(stream, "SoftMax(...) failed");

                delete hostMaxShapeBuffer;
            }
        }
        break;
        default: {
            BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES);
        }
    }
    BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES);
}

////////////////////////////////////////////////////////////////////////
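The deleted branch built softmax out of primitive launches (max, broadcast-subtract, exp, sum, broadcast-divide, optional log), which is exactly the numerically stable recipe; with the legacy path removed, that responsibility moves out of execTransformStrict. For reference, a compact host-side restatement of the same math (plain C++ sketch, not the library's kernels):

#include <algorithm>
#include <cmath>
#include <vector>

// Stable softmax over one row: shift by the row max so exp() cannot
// overflow, then normalize; log-softmax simply takes log afterwards.
std::vector<double> softmaxRow(const std::vector<double> &x) {
    double mx = *std::max_element(x.begin(), x.end());
    double sum = 0.0;
    std::vector<double> z(x.size());
    for (size_t i = 0; i < x.size(); i++) { z[i] = std::exp(x[i] - mx); sum += z[i]; }
    for (auto &v : z) v /= sum;
    return z;
}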
@@ -869,22 +798,8 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
    if (!DataTypeUtils::isR(zType))
        throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType);

    if (opNum == transform::Histogram) {
        dim3 launchDims(256, 256, 32768);

        Nd4jPointer maskedallocationPointer;
        auto length = shape::length(hZShapeInfo);
        cudaMalloc(reinterpret_cast<void **>(&maskedallocationPointer), length * launchDims.x * DataTypeUtils::sizeOf(nd4j::DataType::INT64));
        auto imaskedallocationPointer = reinterpret_cast<int *>(maskedallocationPointer);

        BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, imaskedallocationPointer, reductionPointer, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);

        checkCudaErrors(cudaStreamSynchronize(*stream));
        cudaFree(maskedallocationPointer);
    } else {
        dim3 launchDims(512, 512, 16384);
        BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
    }
    dim3 launchDims(512, 512, 16384);
    BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES);
}
@@ -1197,12 +1112,15 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
    dim3 launchDims = dim3(512, 512, 32768);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(stateHost);

    // functions::random::RandomFunction<float>::executeCudaSingle(launchDims, extraPointers, opNum, stateHost, dZ, dZShapeInfo, extraArguments),
    BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::executeCudaSingle(launchDims, stream, opNum, stateDevice, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES);

    checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream));
    checkCudaErrors(cudaStreamSynchronize(*stream));
    cudaFree(stateDevice);

    rng->rewindH(shape::length(hZShapeInfo));
}

////////////////////////////////////////////////////////////////////////

@@ -1224,14 +1142,17 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
    checkCudaErrors(cudaStreamSynchronize(*stream));
    checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream));

    auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(stateHost);

    dim3 launchDims = dim3(512, 512, 32768);
    auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);
    // functions::random::RandomFunction<float>::executeCudaDouble(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments);
    BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaDouble(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES);

    checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream));
    checkCudaErrors(cudaStreamSynchronize(*stream));
    cudaFree(stateDevice);

    rng->rewindH(shape::length(hZShapeInfo));
}

////////////////////////////////////////////////////////////////////////

@@ -1254,14 +1175,17 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc,
    checkCudaErrors(cudaStreamSynchronize(*stream));
    checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream));

    auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(stateHost);

    dim3 launchDims = dim3(512, 512, 32768);
    auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo);
    // functions::random::RandomFunction<float>::executeCudaTriple(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments);
    BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaTriple(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES);

    checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream));
    checkCudaErrors(cudaStreamSynchronize(*stream));
    cudaFree(stateDevice);

    rng->rewindH(shape::length(hZShapeInfo));
}

////////////////////////////////////////////////////////////////////////
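All three device-side variants now follow one round trip: push the host generator state to the device, launch, pull the mutated state back, synchronize, and only then rewind the host counter. Condensed into one self-contained sketch (error checks trimmed; the Rng type and the kernel launch are placeholders for the library's machinery):

#include <cstdint>
#include <cuda_runtime.h>

// Device round trip for mutable RNG state; all copies ride the same stream,
// so the final synchronize makes the host copy valid before rewindH().
template <typename Rng>
void runRandomOp(Rng *stateHost, size_t sizeOf, cudaStream_t stream, uint64_t producedLen) {
    void *stateDevice = nullptr;
    cudaMalloc(&stateDevice, sizeOf);
    cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, stream);
    // ... launch the random kernel with stateDevice here ...
    cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);   // host copy of the state is now valid
    cudaFree(stateDevice);
    stateHost->rewindH(producedLen); // keep host and device streams aligned
}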
@@ -1958,7 +1958,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers,
                           void *extraArguments) {

    LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
    NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
    NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
}

////////////////////////////////////////////////////////////////////////

@@ -1970,7 +1970,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer st
                           void *extraArguments) {

    LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
    NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
    NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
}

////////////////////////////////////////////////////////////////////////

@@ -1984,7 +1984,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer st
                           void *extraArguments) {

    LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
    NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
    NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments);
}
@@ -152,9 +152,9 @@ namespace functions {
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x) {
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];
            }

            __syncthreads();

@@ -213,9 +213,9 @@ namespace functions {
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x) {
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];
            }

            __syncthreads();

@@ -262,9 +262,9 @@ namespace functions {
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x) {
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];
            }

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;
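Each of these kernels stages the generator into shared memory with the same cooperative byte copy (the hunks drop a stray brace pair around it: old line first, corrected line below it). Standalone, the pattern looks like this (illustrative kernel, not the library's signature; launch with shared-memory size at least stateSize bytes):

__global__ void stageRngState(const unsigned char *dB, int stateSize) {
    extern __shared__ unsigned char cB[];     // shared-memory copy of the state
    // strided byte copy: thread t copies bytes t, t + blockDim.x, t + 2*blockDim.x, ...
    for (int e = threadIdx.x; e < stateSize; e += blockDim.x)
        cB[e] = dB[e];
    __syncthreads();                          // every byte visible before any thread reads
    // ... reinterpret cB as the generator and draw values ...
}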
@@ -107,9 +107,6 @@

#define TRANSFORM_STRICT_OPS \
        (0, SoftMax), \
        (1, SoftMaxDerivative), \
        (2, LogSoftMax), \
        (3, ELUDerivative), \
        (4, TanhDerivative), \
        (5, HardTanhDerivative), \

@@ -167,9 +164,7 @@

// these ops return one of FLOAT data types
#define TRANSFORM_FLOAT_OPS \
        (0, Histogram), \
        (1, Sqrt), \
        (2, Pooling2D), \
        (3, RSqrt)
@@ -0,0 +1,57 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_histogram)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>
#include <ops/declarable/helpers/histogram.h>

namespace nd4j {
    namespace ops {
        CUSTOM_OP_IMPL(histogram, 1, 1, false, 0, 1) {
            auto input = INPUT_VARIABLE(0);
            auto numBins = INT_ARG(0);
            auto output = OUTPUT_VARIABLE(0);

            REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length");

            helpers::histogramHelper(block.launchContext(), *input, *output);

            return Status::OK();
        }

        DECLARE_SHAPE_FN(histogram) {
            auto numBins = INT_ARG(0);

            return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(numBins, nd4j::DataType::INT64));
        }

        DECLARE_TYPES(histogram) {
            getOpDescriptor()
                    ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
                    ->setAllowedOutputTypes({ALL_INTS});
        };
    }
}

#endif
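The former transform::Histogram is now a declarable custom op: one input array, one INT64 output whose length equals the single integer argument. A plausible invocation in libnd4j's test style (sketch only; it assumes the usual DeclarableOp::execute overload taking {inputs}, {tArgs}, {iArgs} and returning a result set):

#include <NDArrayFactory.h>
#include <ops/declarable/CustomOperations.h>

// Sketch: 10-bin histogram over a float vector.
void histogramExample() {
    auto x = NDArrayFactory::create<float>('c', {1000});
    x.linspace(0.0f, 0.1f);

    nd4j::ops::histogram op;
    auto result = op.execute({&x}, {}, {10});   // iArgs = {numBins}
    auto counts = result->at(0);                // INT64 vector, lengthOf() == 10
    delete result;
}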
@@ -206,6 +206,13 @@ namespace nd4j {
        #if NOT_EXCLUDED(OP_hashcode)
        DECLARE_REDUCTION_OP(hashcode, 1, 1, false, 0, 0);
        #endif

        /**
         * This operation calculates the number of entries per bin
         */
        #if NOT_EXCLUDED(OP_histogram)
        DECLARE_CUSTOM_OP(histogram, 1, 1, false, 0, 1);
        #endif
    }
}
@@ -0,0 +1,83 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <ops/declarable/helpers/histogram.h>

namespace nd4j {
    namespace ops {
        namespace helpers {
            template <typename X, typename Z>
            static void histogram_(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) {
                auto dx = reinterpret_cast<X*>(xBuffer);
                auto result = reinterpret_cast<Z*>(zBuffer);

                int length = shape::length(xShapeInfo);
                // FIXME: 2???
                int _threads = 2;

                int span = (length / _threads) + 8;

                X binSize = (max_val - min_val) / (numBins);

                PRAGMA_OMP_PARALLEL_THREADS(_threads)
                {
                    int tid, start, end;

                    int *bins = new int[numBins];
                    std::memset(bins, 0, sizeof(int) * numBins);
                    tid = omp_get_thread_num();
                    start = span * tid;
                    end = span * (tid + 1);
                    if (end > length) end = length;

                    PRAGMA_OMP_SIMD
                    for (int x = start; x < end; x++) {
                        int idx = (int) ((dx[x] - min_val) / binSize);
                        if (idx < 0)
                            idx = 0;
                        else if (idx >= numBins)
                            idx = numBins - 1;

                        bins[idx]++;
                    }

                    PRAGMA_OMP_CRITICAL
                    {
                        PRAGMA_OMP_SIMD
                        for (int x = 0; x < numBins; x++) {
                            result[x] += bins[x];
                        }
                    }

                    delete[] bins;
                }
            }

            void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) {
                Nd4jLong numBins = output.lengthOf();
                double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0);
                double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0);

                BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES);
            }
        }
    }
}
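The CPU helper keeps contention out of the hot loop by giving each OpenMP thread a private bin array and merging once under a critical section; only the merge touches shared data, and out-of-range values are clamped into the first or last bin. The same idea as a self-contained function (hypothetical standalone version, not the library's template):

#include <omp.h>
#include <vector>

// Thread-private histograms merged once at the end; local[] increments never
// race because each thread owns its copy until the critical section.
std::vector<long> histogram(const float *x, int n, int numBins, float lo, float hi) {
    std::vector<long> result(numBins, 0);
    float binSize = (hi - lo) / numBins;
    #pragma omp parallel
    {
        std::vector<long> local(numBins, 0);
        #pragma omp for nowait
        for (int i = 0; i < n; i++) {
            int idx = (int) ((x[i] - lo) / binSize);
            local[idx < 0 ? 0 : (idx >= numBins ? numBins - 1 : idx)]++;
        }
        #pragma omp critical
        for (int b = 0; b < numBins; b++) result[b] += local[b];
    }
    return result;
}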
@@ -0,0 +1,134 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <ops/declarable/helpers/histogram.h>
#include <NDArrayFactory.h>

namespace nd4j {
    namespace ops {
        namespace helpers {
            template <typename X, typename Z>
            void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, double min_val, double max_val) {
                int tid = blockIdx.x * blockDim.x + threadIdx.x;
                auto dx = reinterpret_cast<X*>(xBuffer);
                auto result = reinterpret_cast<Z*>(zBuffer);

                __shared__ Z *bins;
                __shared__ int length;
                __shared__ Z *reductor;
                if (threadIdx.x == 0) {
                    extern __shared__ unsigned char shmem[];
                    bins = (Z *) shmem;
                    reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x);

                    length = shape::length(xShapeInfo);
                }
                __syncthreads();

                Z binSize = (max_val - min_val) / (numBins);

                for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                    bins[e] = (Z) 0.0f;
                }
                __syncthreads();

                for (int e = tid; e < length; e += blockDim.x * gridDim.x) {
                    int idx = (int) ((dx[e] - min_val) / binSize);
                    if (idx < 0) idx = 0;
                    else if (idx >= numBins) idx = numBins - 1;

                    nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f);
                }
                __syncthreads();

                // transfer shared memory to reduction memory

                if (gridDim.x > 1) {
                    unsigned int *tc = (unsigned int *) reductionPointer;
                    __shared__ bool amLast;

                    for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                        reductor[e] = bins[e];
                    }
                    __threadfence();
                    __syncthreads();

                    if (threadIdx.x == 0) {
                        unsigned int ticket = atomicInc(&tc[16384], gridDim.x);
                        amLast = (ticket == gridDim.x - 1);
                    }
                    __syncthreads();

                    if (amLast) {
                        tc[16384] = 0;

                        // nullify shared memory for future accumulation
                        for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                            bins[e] = (Z) 0.0f;
                        }

                        // accumulate reduced bins
                        for (int r = 0; r < gridDim.x; r++) {
                            Z *ptrBuf = ((Z *) allocationPointer) + (r * numBins);

                            for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                                bins[e] += ptrBuf[e];
                            }
                        }
                        __syncthreads();

                        // write them out to Z
                        for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                            result[e] = bins[e];
                        }
                    }
                } else {
                    // if there's only 1 block - just write away data
                    for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                        result[e] = bins[e];
                    }
                }
            }

            template <typename X, typename Z>
            static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) {
                int numThreads = 256;
                int numBlocks = nd4j::math::nd4j_max<int>(256, nd4j::math::nd4j_min<int>(1, shape::length(xShapeInfo) / numThreads));
                int workspaceSize = numBlocks * numBins;
                auto tmp = NDArrayFactory::create<Z>('c', {workspaceSize});

                histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, xShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, min_val, max_val);

                cudaStreamSynchronize(*context->getCudaStream());
            }

            void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) {
                Nd4jLong numBins = output.lengthOf();
                double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0);
                double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0);

                BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES);

                NDArray::registerSpecialUse({&output}, {&input});
            }
        }
    }
}
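The kernel's multi-block merge hinges on a ticket idiom: each block publishes its partial histogram to the global workspace, bumps a global counter via atomicInc, and the block that draws the last ticket re-reads and sums all partials. Isolated into a device helper (a sketch assuming a zero-initialized device counter; the winner must reset it afterwards, as the kernel above does with tc[16384] = 0):

__device__ unsigned int ticketCounter;        // assumed zeroed before launch

__device__ bool lastBlockDone() {
    __threadfence();                          // publish this block's partial results first
    __shared__ bool amLast;
    if (threadIdx.x == 0) {
        unsigned int ticket = atomicInc(&ticketCounter, gridDim.x);
        amLast = (ticket == gridDim.x - 1);   // only the final block sees this
    }
    __syncthreads();
    return amLast;
}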
@@ -0,0 +1,34 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_HISTOGRAM_H
#define LIBND4J_HISTOGRAM_H

#include <NDArray.h>

namespace nd4j {
    namespace ops {
        namespace helpers {
            void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output);
        }
    }
}

#endif //LIBND4J_HISTOGRAM_H
@@ -747,88 +747,7 @@ static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outSha
                        int *allocationPointer, Z *reductionPointer,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    int numBins = (int) extraParams[0];
    Z min_val = extraParams[1];
    Z max_val = extraParams[2];

    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    __shared__ Z *bins;
    __shared__ int length;
    __shared__ Z *reductor;
    if (threadIdx.x == 0) {
        extern __shared__ unsigned char shmem[];
        bins = (Z *) shmem;
        reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x);

        length = shape::length(xShapeBuffer);
    }
    __syncthreads();

    Z binSize = (max_val - min_val) / (numBins);

    for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
        bins[e] = (Z) 0.0f;
    }
    __syncthreads();

    for (int e = tid; e < length; e += blockDim.x * gridDim.x) {
        int idx = (int) ((dx[e] - min_val) / binSize);
        if (idx < 0) idx = 0;
        else if (idx >= numBins) idx = numBins - 1;

        nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f);
    }
    __syncthreads();

    // transfer shared memory to reduction memory

    if (gridDim.x > 1) {
        unsigned int *tc = (unsigned int *) reductionPointer;
        __shared__ bool amLast;

        for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
            reductor[e] = bins[e];
        }
        __threadfence();
        __syncthreads();

        if (threadIdx.x == 0) {
            unsigned int ticket = atomicInc(&tc[16384], gridDim.x);
            amLast = (ticket == gridDim.x - 1);
        }
        __syncthreads();

        if (amLast) {
            tc[16384] = 0;

            // nullify shared memory for future accumulation
            for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                bins[e] = (Z) 0.0f;
            }

            // accumulate reduced bins
            for (int r = 0; r < gridDim.x; r++) {
                Z *ptrBuf = ((Z *) allocationPointer) + (r * numBins);

                for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                    bins[e] += ptrBuf[e];
                }
            }
            __syncthreads();

            // write them out to Z
            for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                result[e] = bins[e];
            }
        }
    } else {
        // if there's only 1 block - just write away data
        for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
            result[e] = bins[e];
        }
    }

};
#endif
@@ -840,52 +759,7 @@ static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outSha
                        Nd4jLong *zShapeBuffer,
                        Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {

    int length = shape::length(xShapeBuffer);
    int _threads = 2;

    int numBins = (int) extraParams[0];
    int span = (length / _threads) + 8;

    // get min over input
    T min_val = extraParams[1];
    T max_val = extraParams[2];

    T binSize = (max_val - min_val) / (numBins);

    PRAGMA_OMP_PARALLEL_THREADS(_threads)
    {
        int tid, start, end;

        int *bins = new int[numBins];
        std::memset(bins, 0, sizeof(int) * numBins);
        tid = omp_get_thread_num();
        start = span * tid;
        end = span * (tid + 1);
        if (end > length) end = length;

        PRAGMA_OMP_SIMD
        for (int x = start; x < end; x++) {
            int idx = (int) ((dx[x] - min_val) / binSize);
            if (idx < 0)
                idx = 0;
            else if (idx >= numBins)
                idx = numBins - 1;

            bins[idx]++;
        }

        PRAGMA_OMP_CRITICAL
        {
            PRAGMA_OMP_SIMD
            for (int x = 0; x < numBins; x++) {
                result[x] += bins[x];
            }
        }

        delete[] bins;
    }
}
@@ -88,10 +88,10 @@ namespace randomOps {
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x) {
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];
            }
            // __syncthreads(); // Eliminated due RTX20xx specific

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

@@ -137,10 +137,6 @@ namespace randomOps {
                    // __syncthreads(); // Eliminated due RTX20xx specific
                }
            }

            // __syncthreads(); // Eliminated due RTX20xx specific
            if (threadIdx.x == 0 && blockIdx.x == 0)
                devRng->rewindH(zLength);
        }
#endif

@@ -208,9 +204,6 @@ namespace randomOps {
                    }
                }
            }

            // update rng state
            rng->rewindH(zLength);
        }
    };

@@ -277,7 +270,7 @@ namespace randomOps {
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];

            // __syncthreads(); // Eliminated due RTX20xx specific
            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

@@ -299,10 +292,6 @@ namespace randomOps {
                    z[epm * zEWS] = (nd4j::math::nd4j_sqrt<T,T>(t * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean1;
                }
            }

            // __syncthreads(); // Eliminated due RTX20xx specific
            if (threadIdx.x == 0 && blockIdx.x == 0)
                devRng->rewindH(zLength);
        }
#endif

@@ -352,9 +341,6 @@ namespace randomOps {
                    z[epm * zEWS] = z1;
                }
            }

            // update rng state
            rng->rewindH(zLength);
        }
    };

@@ -401,10 +387,10 @@ namespace randomOps {
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x) {
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];
            }
            // __syncthreads(); // Eliminated due RTX20xx specific

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

@@ -423,10 +409,6 @@ namespace randomOps {
                // if trials is set to 0, effectively we just have successful memset
                z[e * zEWS] = static_cast<T>(success);
            }

            // __syncthreads(); // Eliminated due RTX20xx specific
            if (trials > 0 && threadIdx.x == 0 && blockIdx.x == 0)
                devRng->rewindH(zLength * trials);
        }
#endif

@@ -472,10 +454,6 @@ namespace randomOps {
                    z[e * zEWS] = static_cast<T>(success);
                }
            }

            // update rng state
            if (trials > 0)
                rng->rewindH(zLength * trials);
        }
    };

@@ -522,10 +500,10 @@ namespace randomOps {
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x) {
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];
            }
            // __syncthreads(); // Eliminated due RTX20xx specific

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

@@ -591,10 +569,6 @@ namespace randomOps {
                    z[e * zEWS] = static_cast<T>(success);
                }
            }

            // update rng state
            if (trials > 0)
                rng->rewindH(zLength * trials);
        }
    };

@@ -676,10 +650,10 @@ namespace randomOps {
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x) {
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];
            }
            // __syncthreads(); // Eliminated due RTX20xx specific

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

@@ -724,9 +698,6 @@ namespace randomOps {
                    z[e] = mean + nd4j::DataTypeUtils::min<T>();
                }
            }

            // update rng state
            rng->rewindH(zLength);
        }
    };

@@ -788,10 +759,10 @@ namespace randomOps {
            __syncthreads();

            // using this loop instead of memcpy
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x) {
            for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e += blockDim.x)
                cB[e] = dB[e];
            }
            // __syncthreads(); // Eliminated due RTX20xx specific

            __syncthreads();

            int tid = blockIdx.x * blockDim.x + threadIdx.x;

@@ -868,9 +839,6 @@ namespace randomOps {
                    }
                }
            }

            // update rng state
            rng->rewindH(zLength);
        }
    };
@@ -1724,9 +1724,9 @@ namespace nd4j {
            }
        };

        TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");
        //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");

        output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax");
        //output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax");

        return output;
    }

@@ -60,11 +60,11 @@ namespace nd4j {
        sbRelu.setY(NDArrayFactory::create_<T>(0.0));

        TransformBenchmark tbSigmoid(transform::StrictOps::Sigmoid, "sigmoid");
        TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");
        //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");

        output += helper.runOperationSuit(&sbRelu, generator, batch, "RELU");
        output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Sigmoid");
        output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax");
        //output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax");

        return output;
    }
@@ -49,7 +49,6 @@ import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;

@@ -748,8 +747,6 @@ public class LegacyOpMapper {

    public static Class<?> transformFloatingOpClass(int opNum){
        switch (opNum){
            case 0:
                return Histogram.class;
            case 1:
                return Sqrt.class;
            case 3:
@@ -70,10 +70,10 @@ import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;

@@ -888,7 +888,6 @@ public class OpValidation {
                SELUDerivative.class,
                SigmoidDerivative.class,
                org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class,
                org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative.class,
                SoftSignDerivative.class,
                TanhDerivative.class,
                SwishDerivative.class,

@@ -987,7 +986,6 @@ public class OpValidation {
                OneHot.class,
                BinaryMinimalRelativeError.class,
                BinaryMinimalRelativeError.class,
                Histogram.class,
                InvertPermutation.class, //Uses integer indices
                ConfusionMatrix.class, //Integer indices
                Linspace.class, //No input array

@@ -1056,7 +1054,9 @@ public class OpValidation {
                LogicalAnd.class,
                LogicalNot.class,
                LogicalOr.class,
                LogicalXor.class
                LogicalXor.class,

                Histogram.class
        );

        return new HashSet<>(list);
@@ -424,7 +424,6 @@ public class ImportClassMapping {
            org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd.class,
            org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum.class,
            org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast.class,
            org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram.class,
            org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt.class,
            org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class,
            org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class,

@@ -552,7 +551,6 @@ public class ImportClassMapping {
            org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class,
            org.nd4j.linalg.api.ops.impl.transforms.strict.Sin.class,
            org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh.class,
            org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative.class,
            org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus.class,
            org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign.class,
            org.nd4j.linalg.api.ops.impl.transforms.strict.Stabilize.class,
@@ -19,20 +19,6 @@ package org.nd4j.linalg.activations;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.*;
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;

/**
 * This enum is the factory for the activation function.
@@ -0,0 +1,55 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.transforms;

import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Histogram op wrapper
 *
 * @author raver119@gmail.com
 */
public class Histogram extends DynamicCustomOp {
    private long numBins;

    public Histogram() {

    }

    public Histogram(INDArray input, INDArray output) {
        Preconditions.checkArgument(output.isZ(), "Histogram op output should have integer data type");

        numBins = output.length();
        inputArguments.add(input);
        outputArguments.add(output);
        iArguments.add(numBins);
    }

    public Histogram(INDArray input, long numBins) {
        this.numBins = numBins;
        inputArguments.add(input);
        iArguments.add(numBins);
    }

    @Override
    public String opName() {
        return "histogram";
    }
}
@@ -1,114 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.transforms.floating;

import lombok.val;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * @author raver119@gmail.com
 */
public class Histogram extends BaseTransformFloatOp {
    public Histogram(SameDiff sameDiff, SDVariable i_v, boolean inPlace, int numBins) {
        super(sameDiff, i_v, inPlace);
        this.numBins = numBins;
    }

    public Histogram(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs, int numBins) {
        super(sameDiff, i_v, shape, inPlace, extraArgs);
        this.numBins = numBins;
    }

    public Histogram(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, int numBins) {
        super(sameDiff, i_v, extraArgs);
        this.numBins = numBins;
    }

    private int numBins = 0;

    public Histogram() {
        //no-op
    }

    public Histogram(INDArray x, INDArray z) {
        setX(x);
        setZ(z);

        //FIXME: int cast
        numBins = (int) z.length();

        double max = x.maxNumber().doubleValue();
        double min = x.minNumber().doubleValue();

        this.extraArgs = new Object[] {(double) numBins, min, max};
    }

    public Histogram(INDArray x, int numberOfBins) {
        this(x, Nd4j.create(x.dataType(), numberOfBins));
    }

    @Override
    public Map<String, Object> propertiesForFunction() {
        Map<String, Object> ret = new LinkedHashMap<>();
        ret.put("numBins", numBins);
        return ret;
    }

    @Override
    public int opNum() {
        return 0;
    }

    @Override
    public String opName() {
        return "histogram";
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
@@ -1,76 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.transforms.strict;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;

import java.nio.Buffer;

/**
 * Softmax derivative
 *
 * @author Adam Gibson
 */
public class SoftMaxDerivative extends SoftMax {
    public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
        super(sameDiff, new SDVariable[]{i_v1, i_v2});
    }

    public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) {
        super(sameDiff, new SDVariable[]{i_v1, i_v2}, inPlace);
    }

    public SoftMaxDerivative(INDArray x, INDArray z) {
        super(x, z);
    }

    public SoftMaxDerivative(INDArray x) {
        super(x, x);
    }

    public SoftMaxDerivative() {}

    @Override
    public int opNum() {
        return 1;
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }

    @Override
    public String opName() {
        return "_softmaxderivative";
    }
}
@@ -1529,85 +1529,40 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
        Nd4j.getExecutioner().commit();

        AllocationPoint old = allocationPoint;
        allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false);
        AllocationPoint old = allocationPoint;
        allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false);

        Nd4j.getDeallocatorService().pickObject(this);
        trackingPoint = allocationPoint.getObjectId();
        Nd4j.getDeallocatorService().pickObject(this);
        trackingPoint = allocationPoint.getObjectId();
        val oldLength = this.length;
        this.length = length;

        switch (dataType()) {
            case DOUBLE:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asDoublePointer();
                indexer = DoubleIndexer.create((DoublePointer) pointer);
                break;
            case FLOAT:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asFloatPointer();
                indexer = FloatIndexer.create((FloatPointer) pointer);
                break;
            case BFLOAT16:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer();
                indexer = Bfloat16Indexer.create((ShortPointer) pointer);
                break;
            case HALF:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer();
                indexer = ShortIndexer.create((ShortPointer) pointer);
                break;
            case LONG:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer();
                indexer = LongIndexer.create((LongPointer) pointer);
                break;
            case UINT64:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer();
                indexer = LongIndexer.create((LongPointer) pointer);
                break;
            case INT:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer();
                indexer = IntIndexer.create((IntPointer) pointer);
                break;
            case UINT32:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer();
                indexer = IntIndexer.create((IntPointer) pointer);
                break;
            case SHORT:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer();
                indexer = ShortIndexer.create((ShortPointer) pointer);
                break;
            case UINT16:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer();
                indexer = UShortIndexer.create((ShortPointer) pointer);
                break;
            case BYTE:
                this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer();
                indexer = ByteIndexer.create((BytePointer) pointer);
                break;
            default:
                throw new UnsupportedOperationException();
        }
        // if original buffer had host pointer allocated, we'll reallocate host buffer as well
        if (old.getHostPointer() != null) {
            lazyAllocateHostPointer();
        }

        CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream());
        val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream());

        MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
        val perfD = PerformanceTracker.getInstance().helperStartTransaction();
        MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
        val perfD = PerformanceTracker.getInstance().helperStartTransaction();

        if (old.isActualOnDeviceSide()) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), this.length * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, context.getSpecialStream());
        } else if (old.isActualOnHostSide()) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), this.length * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
            direction = MemcpyDirection.HOST_TO_DEVICE;
        }
        if (old.isActualOnDeviceSide()) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), oldLength * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, context.getSpecialStream());
        } else if (old.isActualOnHostSide()) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), oldLength * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
            direction = MemcpyDirection.HOST_TO_DEVICE;
        }

        context.getSpecialStream().synchronize();
        context.getSpecialStream().synchronize();

        PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction);
        PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction);

        allocationPoint.tickDeviceWrite();
        // we're keeping pointer reference for JVM
        pointer.address();
        allocationPoint.tickDeviceWrite();

        // we need to update length with new value now
        //this.length = length;
        // we need to update length with new value now
        //this.length = length;
        if (isAttached()) {
            // do nothing here, that's workspaces
        } else {
@ -1619,7 +1574,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda

    @Override
    public long capacity() {
        return pointer.capacity();
        if (allocationPoint.getHostPointer() != null)
            return pointer.capacity();
        else
            return length;
    }
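The guard matters because host-pointer allocation is now lazy (see lazyAllocateHostPointer above): a device-only buffer has no JavaCPP pointer to query, so capacity() falls back to the buffer's logical length.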

    @Override
@ -16755,6 +16755,27 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
        }
// #endif

    /**
     * This operation calculates the number of entries per bin
     */
// #if NOT_EXCLUDED(OP_histogram)
    @Namespace("nd4j::ops") public static class histogram extends DeclarableCustomOp {
        static { Loader.load(); }
        /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
        public histogram(Pointer p) { super(p); }
        /** Native array allocator. Access with {@link Pointer#position(long)}. */
        public histogram(long size) { super((Pointer)null); allocateArray(size); }
        private native void allocateArray(long size);
        @Override public histogram position(long position) {
            return (histogram)super.position(position);
        }

        public histogram() { super((Pointer)null); allocate(); }
        private native void allocate();
        public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
    }
// #endif


// #endif
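For orientation, here is a minimal sketch of driving the new binding from Java. It assumes the generic DynamicCustomOp builder (the org.nd4j.linalg.api.ops.impl.transforms.Histogram wrapper used in the tests below routes to the same native op) and assumes the op takes one input array plus the bin count as an integer argument; the class name HistogramExample is hypothetical.

```java
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

public class HistogramExample {
    public static void main(String[] args) {
        // 100k samples spread evenly over [1, 1000], bucketed into 20 bins
        INDArray x = Nd4j.linspace(1, 1000, 100000, DataType.DOUBLE);

        DynamicCustomOp op = DynamicCustomOp.builder("histogram")
                .addInputs(x)
                .addIntegerArguments(20)   // number of bins
                .build();

        // exec() on a custom op returns its output arrays
        INDArray counts = Nd4j.getExecutioner().exec(op)[0];
        System.out.println(counts);        // uniform input -> ~5000 per bin
    }
}
```
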
@ -79,7 +79,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
@ -931,13 +930,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
        System.out.println("Second: OK");
    }

    @Test
    public void testSoftmaxDerivative() {
        INDArray input = Nd4j.create(new double[] {-1.07, -0.01, 0.45, 0.95, 0.45, 0.16, 0.20, 0.80, 0.89, 0.25}).reshape(1, -1).transpose();
        INDArray output = Nd4j.create(10, 1);
        Nd4j.getExecutioner().exec(new SoftMaxDerivative(input, output));
    }


    @Test
    public void testVStackDifferentOrders() {
@ -7245,19 +7237,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
        }
    }


    @Test
    public void testInvalidTransformsSoftmax(){
        INDArray arr = Nd4j.zeros(2,3,4);
        try{
            Transforms.softmax(arr);
            fail("Expected exception");
        } catch (IllegalArgumentException e){
            //OK
            assertTrue(e.getMessage().contains("rank 2"));
        }
    }

    @Test
    public void testEmptyCasting(){
        for(val from : DataType.values()) {
@ -17,6 +17,7 @@
package org.nd4j.linalg.api.buffer;


import lombok.val;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.memory.MemoryWorkspace;

@ -28,6 +29,7 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

import java.io.*;
import java.util.Arrays;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@ -85,10 +87,12 @@ public class IntDataBufferTests extends BaseNd4jTest {
    public void testReallocation() {
        DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
        assertEquals(4, buffer.capacity());
        int[] old = buffer.asInt();
        buffer.reallocate(6);
        val old = buffer.asInt();
        assertEquals(6, buffer.capacity());
        assertArrayEquals(old, buffer.asInt());
        val newContent = buffer.asInt();
        assertEquals(6, newContent.length);
        assertArrayEquals(old, Arrays.copyOf(newContent, old.length));
    }

    @Test
@ -98,12 +102,14 @@ public class IntDataBufferTests extends BaseNd4jTest {
        MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");

        DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
        int[] old = buffer.asInt();
        val old = buffer.asInt();
        assertTrue(buffer.isAttached());
        assertEquals(4, buffer.capacity());
        buffer.reallocate(6);
        assertEquals(6, buffer.capacity());
        assertArrayEquals(old, buffer.asInt());
        val newContent = buffer.asInt();
        assertEquals(6, newContent.length);
        assertArrayEquals(old, Arrays.copyOf(newContent, old.length));
        workspace.close();
    }
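Both tests now encode the same contract, sketched below. The zeroed tail is an assumption implied by the memsetAsync before the copy in the CUDA reallocation path above; the tests themselves only assert that the original prefix survives.

```java
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.factory.Nd4j;

public class ReallocationExample {
    public static void main(String[] args) {
        DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
        buffer.reallocate(6);              // grow in place; device buffer is re-allocated
        int[] content = buffer.asInt();    // now reports the new capacity: 6 elements
        // expected: {1, 2, 3, 4, 0, 0} -- prefix preserved, tail zeroed by the memset
        System.out.println(java.util.Arrays.toString(content));
    }
}
```
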
@ -82,12 +82,12 @@ public class RngTests extends BaseNd4jTest {
        INDArray narr = Nd4j.randn('c', rows, cols);
        assertArrayEquals(new long[] {rows, cols}, narr.shape());
        assertEquals('c', narr.ordering());
        assertEquals(narr.meanNumber().doubleValue(), 0.0, 0.05);
        assertEquals(0.0, narr.meanNumber().doubleValue(), 0.05);

        INDArray narr2 = Nd4j.randn('f', rows, cols);
        assertArrayEquals(new long[] {rows, cols}, narr2.shape());
        assertEquals('f', narr2.ordering());
        assertEquals(narr2.meanNumber().doubleValue(), 0.0, 0.05);
        assertEquals(0.0, narr2.meanNumber().doubleValue(), 0.05);

        INDArray narr3 = Nd4j.randn('c', new int[] {rows, cols, dim2});
        assertArrayEquals(new long[] {rows, cols, dim2}, narr3.shape());

@ -97,7 +97,7 @@ public class RngTests extends BaseNd4jTest {
        INDArray narr4 = Nd4j.randn('f', new int[] {rows, cols, dim2});
        assertArrayEquals(new long[] {rows, cols, dim2}, narr4.shape());
        assertEquals('f', narr4.ordering());
        assertEquals(narr4.meanNumber().doubleValue(), 0.0, 0.05);
        assertEquals(0.0, narr4.meanNumber().doubleValue(), 0.05);

    }
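The fix in this file is purely argument order: JUnit's three-argument overload is assertEquals(expected, actual, delta), so the expected mean of 0.0 belongs first; with the operands swapped, a failure message would report expected and actual backwards.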
@ -29,7 +29,6 @@ import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.BooleanIndexing;
@ -159,7 +158,6 @@ public class CrashTest extends BaseNd4jTest {

        // logsoftmax, softmax & softmax derivative
        Nd4j.getExecutioner().exec((CustomOp) new SoftMax(x));
        Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(x));
        Nd4j.getExecutioner().exec((CustomOp) new LogSoftMax(x));


@ -441,8 +441,8 @@ public class BooleanIndexingTest extends BaseNd4jTest {
        Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC);
        NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(true);
        INDArray arr = Nd4j.linspace(1,4,4, Nd4j.dataType()).reshape(2,2);
        INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{arr},Arrays.asList(2.0), Collections.emptyList(),new GreaterThan());
        assertEquals(4,filtered.length());
        INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{arr}, Arrays.asList(2.0), Collections.emptyList(),new GreaterThan());
        assertEquals(2, filtered.length());
    }
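The corrected expectation follows directly from the data: linspace(1, 4, 4) produces {1, 2, 3, 4}, and chooseFrom with GreaterThan against 2.0 keeps only the strictly larger entries {3, 4}, so the filtered length is 2, not 4.
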
@ -450,7 +450,7 @@ public class BooleanIndexingTest extends BaseNd4jTest {
    public void testChooseGreaterThanZero() {
        INDArray zero = Nd4j.linspace(0,4,4, Nd4j.dataType());
        INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{zero},Arrays.asList(0.0), Collections.emptyList(),new GreaterThan());
        assertEquals(3,filtered.length());
        assertEquals(3, filtered.length());
    }

    @Test
@ -331,7 +331,7 @@ public class MixedDataTypesTests extends BaseNd4jTest {
        assertArrayEquals(exp, arr);
    }

    @Test(expected = IllegalArgumentException.class)
    @Test(expected = RuntimeException.class)
    public void testTypesValidation_3() {
        val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT);

@ -195,47 +195,6 @@ public class DerivativeTests extends BaseNd4jTest {

    }

    @Test
    public void testSoftMaxDerivative() {
        Random r = new Random(12345L);

        int[] mb = new int[] {10, 2, 1};
        for (int minibatch : mb) {
            System.out.println("Minibatch size: " + minibatch);
            INDArray z = Nd4j.zeros(minibatch, 5);
            double[][] in = new double[minibatch][5];
            double[][] softmax = new double[minibatch][5];
            double[][] expOut = new double[minibatch][5];
            for (int i = 0; i < minibatch; i++) {
                double rowSumExp = 0.0;
                for (int j = 0; j < 5; j++) {
                    in[i][j] = 10 * r.nextDouble();
                    z.putScalar(new int[] {i, j}, in[i][j]);
                    rowSumExp += FastMath.exp(in[i][j]);
                }
                for (int j = 0; j < 5; j++) {
                    softmax[i][j] = FastMath.exp(in[i][j]) / rowSumExp;
                    expOut[i][j] = softmax[i][j] * (1.0 - softmax[i][j]);
                }
            }

            INDArray sm = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(z.dup()))[0];
            INDArray zPrime = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(z))[0];
            System.out.println(Arrays.toString(sm.data().asDouble()));
            System.out.println(Arrays.toString(zPrime.data().asDouble()));
            assertNotEquals(sm, zPrime);

            for (int i = 0; i < minibatch; i++) {
                for (int j = 0; j < 5; j++) {
                    double relError = Math.abs(expOut[i][j] - zPrime.getDouble(i, j))
                            / (Math.abs(expOut[i][j]) + Math.abs(zPrime.getDouble(i, j)));
                    // System.out.println("Error: " + relError);
                    assertTrue(relError < REL_ERROR_TOLERANCE);
                }
            }
        }
    }


    @Test
    public void testSoftPlusDerivative() {
@ -384,180 +343,6 @@ public class DerivativeTests extends BaseNd4jTest {
        }
    }

    @Test
    public void softmaxsimpleLossTest() {
        /*
            The softmax derivative is correct standalone, but when we apply it in the
            chain rule the current derivative function is incomplete.
            For this test, I am assuming that the function of interest is just MSE.
            What the fix is:
            The derivative of a softmax needs to return a rank 2 matrix.
            Right now we get only the diagonal elements of this matrix.
            http://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network
         */
        //random array representing preout
        INDArray X = Nd4j.rand(1, 2);
        //preout transformed to y_hat with softmax
        INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];

        //hard coding something to construct a function with, using MSE
        INDArray Y = Nd4j.create(new double[][] {{0.123, 1 - 0.123}});

        //This is the MSE now
        double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();

        INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];

        //the way we apply the chain rule now is 2*(y-yhat)*softmaxder
        INDArray dLdY = Y.sub(YHat).mul(-2);
        INDArray currentGradient = dLdY.mul(softmaxDer);

        //what I think we should be doing
        // we have x0, x1 -> y0,y1
        // we need the derivatives of the output of the softmax wrt every input (x0,x1)
        // we only have dy0/dx0 and dy1/dx1
        // we also need dy0/dx1 and dy1/dx0
        // the below is the chain rule from calculus, applied when L is a function of y0,y1; y0 and y1 are in turn functions of BOTH (x0 and x1)
        // dL/dx0 = (dL/dy0) * (dy0/dx0) + (dL/dy1) * (dy1/dx0)
        // dL/dx1 = (dL/dy0) * (dy0/dx1) + (dL/dy1) * (dy1/dx1)
        // worked it out on paper and googled it (should have googled first, the formula is in the link above)
        // dy0/dx0 = y0*(1-y0) = y0*y1
        // dy1/dx0 = -y1*(1-y1) = -y0*y1
        // dy0/dx1 = -y0*(1-y0) = -y0*y1
        // dy1/dx1 = y1*(1-y1) = y0*y1
        //[ dL/dy0 dL/dy1] [[dy0/dx0 dy1/dx0] [dy0/dx1 dy1/dx1]]
        double y0y1 = softmaxDer.getDouble(0, 0);
        //hack, but this is what we need to implement; straightforward here but complicated for >2
        //INDArray mysoftmaxDer = Nd4j.create(new double[][] {{y0y1,y0y1*-1},{-1*y0y1,y0y1}});
        INDArray mysoftmaxDer = correctSoftmax(X);
        INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1);

        double epsilon = 0.0001;
        INDArray Xiplus, Ximinus;
        INDArray YHatplus, YHatminus;
        double lossplus, lossminus;

        INDArray numGradient = Nd4j.zeros(1, 2);

        for (int i = 0; i < 2; i++) {
            /* change X one value at a time */

            // +epsilon
            double x = X.getDouble(0, i);
            Xiplus = X.dup();
            Xiplus.put(0, i, x + epsilon);
            YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0];
            lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();

            // -epsilon
            Ximinus = X.dup();
            Ximinus.put(0, i, x - epsilon);
            YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0];
            lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();

            double gradienti = (lossplus - lossminus) / (2 * epsilon);
            numGradient.put(0, i, gradienti);
        }
        System.out.println("=========================");
        System.out.println("NUMERICAL:");
        System.out.println(numGradient);
        System.out.println("\nCURRENTLY:");
        System.out.println(currentGradient);
        System.out.println("\nMY GRADIENT:");
        System.out.println(myGradient + "\n");
        System.out.println(
                        "Because of the nature of the derivative of the softmax for length = 2, our current method will make it off by a factor of 2");
        System.out.println("=========================");
    }
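For reference, the full Jacobian these removed tests were probing follows from the definition of softmax (this is standard calculus, not specific to this codebase):

```latex
\frac{\partial y_i}{\partial x_j} = y_i\,(\delta_{ij} - y_j),
\qquad y_i = \frac{e^{x_i}}{\sum_k e^{x_k}}
```

so the diagonal entries are y_i(1 - y_i) and the off-diagonals are -y_i y_j; for length 2, y_1 = 1 - y_0, which makes every entry equal to plus or minus y_0 y_1 — exactly the y0y1 shortcut used above.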

    @Test
    public void softmaxsimplelongerlengthLossTest() {
        /*
            Read the comments in the earlier test for length = 2.
         */
        //random array representing preout
        int someLength = 7;

        INDArray X = Nd4j.rand(1, someLength);
        //preout transformed to y_hat with softmax
        INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];

        //hard coding something to construct a function with, using MSE
        INDArray temp = Nd4j.rand(1, someLength);
        INDArray Y = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(temp))[0];

        //This is the MSE now
        double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue();

        INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];

        //the way we apply the chain rule now is 2*(y-yhat)*softmaxder
        INDArray dLdY = Y.sub(YHat).mul(-2);
        INDArray currentGradient = dLdY.mul(softmaxDer);

        INDArray mysoftmaxDer = correctSoftmax(X);
        INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1);

        double epsilon = 0.0001;
        INDArray Xiplus, Ximinus;
        INDArray YHatplus, YHatminus;
        double lossplus, lossminus;

        INDArray numGradient = Nd4j.zeros(1, someLength);

        for (int i = 0; i < someLength; i++) {
            /* change X one value at a time */

            // +epsilon
            double x = X.getDouble(0, i);
            Xiplus = X.dup();
            Xiplus.put(0, i, x + epsilon);
            YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0];
            lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue();

            // -epsilon
            Ximinus = X.dup();
            Ximinus.put(0, i, x - epsilon);
            YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0];
            lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue();

            double gradienti = (lossplus - lossminus) / (2 * epsilon);
            numGradient.put(0, i, gradienti);
        }
        System.out.println("=========================");
        System.out.println("NUMERICAL GRADIENT:");
        System.out.println(new NDArrayStrings(6).format(numGradient).toString());
        System.out.println("\nANALYTIC USING EXISTING SOFTMAX DER:");
        System.out.println(new NDArrayStrings(6).format(currentGradient).toString());
        System.out.println("\nGRADIENT USING MY VERSION OF SOFTMAX DER:");
        System.out.println(new NDArrayStrings(6).format(myGradient).toString());
        System.out.println("=========================");
    }


    public static INDArray correctSoftmax(INDArray X) {
        // this is only for X a row vector
        // should return a rank 2 matrix: diagonal elements are pi*(1-pi),
        // the rest are -pi*pj
        INDArray p = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0];
        INDArray pCol = p.dup().transpose();
        INDArray pipj = pCol.mmul(p);
        pipj.muli(-1);

        // so now pipj is correct except for the diagonal elements,
        // which, by the way, are what our current softmax derivative gives us
        INDArray diagp = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0];

        // ugly for loop to correct the diagonal elements
        for (int i = 0; i < X.length(); i++) {
            pipj.put(i, i, diagp.getDouble(0, i));
        }

        return pipj;
    }

    @Override
    public char ordering() {
        return 'f';
@ -50,9 +50,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.impl.transforms.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp;
@ -876,14 +877,14 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
    @Test
    public void testHistogram1() {
        INDArray x = Nd4j.linspace(1, 1000, 100000, DataType.DOUBLE);
        INDArray z = Nd4j.zeros(DataType.DOUBLE,new long[]{20});
        INDArray z = Nd4j.zeros(DataType.LONG,new long[]{20});

        INDArray xDup = x.dup();
        INDArray zDup = z.dup();

        INDArray zExp = Nd4j.create(DataType.DOUBLE, 20).assign(5000);
        INDArray zExp = Nd4j.create(DataType.LONG, 20).assign(5000);

        Histogram histogram = new Histogram(x, z);
        val histogram = new Histogram(x, z);

        Nd4j.getExecutioner().exec(histogram);
@ -899,16 +900,13 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
    public void testHistogram2() {
        INDArray x = Nd4j.create(new float[] {0f, 0f, 0f, 5f, 5f, 5f, 10f, 10f, 10f});

        INDArray xDup = x.dup();

        INDArray zExp = Nd4j.zeros(DataType.FLOAT, 10).putScalar(0, 3f).putScalar(5, 3f).putScalar(9, 3f);
        INDArray zExp = Nd4j.zeros(DataType.LONG, 10).putScalar(0, 3).putScalar(5, 3).putScalar(9, 3);

        Histogram histogram = new Histogram(x, 10);
        val histogram = new Histogram(x, 10);

        Nd4j.getExecutioner().exec(histogram);

        INDArray z = histogram.z();
        val z = Nd4j.getExecutioner().exec(histogram)[0];

        assertEquals(xDup, x);
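Worth noting: with Histogram now a custom op, exec(...) hands back the op's output arrays, so the result is read from the returned array rather than from a pre-bound z — and the bin counts come back as LONG instead of the input's floating-point type.
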
@ -19,6 +19,7 @@ package org.nd4j.linalg.serde;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.val;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
@ -58,7 +59,7 @@ public class JsonSerdeTests extends BaseNd4jTest {
        Nd4j.getRandom().setSeed(12345);
        INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4).muli(20).subi(10);

        ObjectMapper om = new ObjectMapper();
        val om = new ObjectMapper();

        for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT, DataType.SHORT,
                DataType.BYTE, DataType.UBYTE, DataType.BOOL, DataType.UTF8}) {
@ -673,7 +673,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {

        workspace.initializeWorkspace();
        long reqMemory = 12 * Nd4j.sizeOfDataType(arrayCold.dataType());
        assertEquals(reqMemory + reqMemory % 8 + Nd4j.sizeOfDataType(DOUBLE), workspace.getCurrentSize());
        assertEquals(reqMemory + reqMemory % 8, workspace.getCurrentSize());


        log.info("-----------------------");
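In numbers: 12 doubles are 12 × 8 = 96 bytes, and 96 % 8 = 0, so the expected workspace size is exactly 96. The stray Nd4j.sizeOfDataType(DOUBLE) term dropped here (and in the offset and spill assertions below) was presumably compensating for an extra 8 bytes the allocator no longer reserves.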
@ -692,7 +692,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {

        array.addi(1.0);

        assertEquals(reqMem + reqMem % 8 + Nd4j.sizeOfDataType(DOUBLE), workspace.getPrimaryOffset());
        assertEquals(reqMem + reqMem % 8, workspace.getPrimaryOffset());

        assertEquals("Failed on iteration " + x, 10, array.sumNumber().doubleValue(), 0.01);

@ -746,7 +746,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest {

        INDArray dup = array.dup();

        assertEquals((reqMemory + reqMemory % 8) * 2 + Nd4j.sizeOfDataType(DOUBLE), workspace.getPrimaryOffset());
        assertEquals((reqMemory + reqMemory % 8) * 2, workspace.getPrimaryOffset());

        assertEquals(5, dup.sumNumber().doubleValue(), 0.01);

@ -91,7 +91,7 @@ public class DebugModeTests extends BaseNd4jTest {
        assertEquals(0, ws.getDeviceOffset());

        // array buffer should be spilled now
        assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE) + Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize());
        assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize());
    }
}

@ -118,7 +118,7 @@ public class DebugModeTests extends BaseNd4jTest {
        assertEquals(0, ws.getDeviceOffset());

        // array buffer should be spilled now
        assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE) + Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize());
        assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize());
    }

    try (val ws = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "R_119_1992")) {