From fd6c0df024dab13f9d794415993cfb6eabbac400 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 16 Jul 2019 18:48:40 +0300 Subject: [PATCH] [WIP] More CUDA fixes/updates (#62) * CUDA reallocation update Signed-off-by: raver119 * Legacy SoftMax/LogSoftMax/SoftMaxDerivative removed from cpp Signed-off-by: raver119 * SoftMaxDerivative op removed Signed-off-by: raver119 * few tests updates Signed-off-by: raver119 * RNG fixes Signed-off-by: raver119 * few more tests updates Signed-off-by: raver119 * legacy Histogram/Pooling2D removed Signed-off-by: raver119 * legacy Histogram removed Signed-off-by: raver119 * histogram moved Signed-off-by: raver119 * histogram moved cuda Signed-off-by: raver119 * Histogram custom op Signed-off-by: raver119 --- libnd4j/blas/cpu/NativeOpExecutioner.cpp | 9 + libnd4j/blas/cuda/NativeOpExecutioner.cu | 106 ++------- libnd4j/blas/cuda/NativeOps.cu | 6 +- libnd4j/include/loops/cuda/random.cu | 12 +- libnd4j/include/loops/legacy_ops.h | 5 - .../generic/transforms/histogram.cpp | 57 +++++ .../ops/declarable/headers/transforms.h | 7 + .../ops/declarable/helpers/cpu/histogram.cpp | 83 +++++++ .../ops/declarable/helpers/cuda/histogram.cu | 134 +++++++++++ .../ops/declarable/helpers/histogram.h | 34 +++ libnd4j/include/ops/special_ops.h | 126 ---------- libnd4j/include/ops/special_random_ops.h | 64 ++---- .../benchmarking/impl/FullBenchmarkSuit.cpp | 4 +- .../benchmarking/impl/LightBenchmarkSuit.cpp | 4 +- .../samediff/serde/LegacyOpMapper.java | 3 - .../autodiff/validation/OpValidation.java | 8 +- .../converters/ImportClassMapping.java | 2 - .../nd4j/linalg/activations/Activation.java | 14 -- .../api/ops/impl/transforms/Histogram.java | 55 +++++ .../impl/transforms/floating/Histogram.java | 114 ---------- .../transforms/strict/SoftMaxDerivative.java | 76 ------- .../jcublas/buffer/BaseCudaDataBuffer.java | 100 +++----- .../java/org/nd4j/nativeblas/Nd4jCpu.java | 21 ++ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 21 -- .../linalg/api/buffer/IntDataBufferTests.java | 14 +- .../org/nd4j/linalg/api/rng/RngTests.java | 6 +- .../java/org/nd4j/linalg/crash/CrashTest.java | 2 - .../linalg/indexing/BooleanIndexingTest.java | 6 +- .../linalg/mixed/MixedDataTypesTests.java | 2 +- .../org/nd4j/linalg/ops/DerivativeTests.java | 215 ------------------ .../nd4j/linalg/ops/OpExecutionerTestsC.java | 18 +- .../org/nd4j/linalg/serde/JsonSerdeTests.java | 3 +- .../linalg/workspace/BasicWorkspaceTests.java | 6 +- .../nd4j/linalg/workspace/DebugModeTests.java | 4 +- 34 files changed, 509 insertions(+), 832 deletions(-) create mode 100644 libnd4j/include/ops/declarable/generic/transforms/histogram.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/histogram.cu create mode 100644 libnd4j/include/ops/declarable/helpers/histogram.h create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Histogram.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Histogram.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftMaxDerivative.java diff --git a/libnd4j/blas/cpu/NativeOpExecutioner.cpp b/libnd4j/blas/cpu/NativeOpExecutioner.cpp index 986f76469..a87ca4e30 100644 --- a/libnd4j/blas/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/blas/cpu/NativeOpExecutioner.cpp @@ -827,6 +827,9 @@ void 
NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc, auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); + + auto rng = reinterpret_cast(state); + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// @@ -842,6 +845,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc, auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); + + auto rng = reinterpret_cast(state); + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// @@ -859,6 +865,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo); BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); + + auto rng = reinterpret_cast(state); + rng->rewindH(shape::length(hZShapeInfo)); } diff --git a/libnd4j/blas/cuda/NativeOpExecutioner.cu b/libnd4j/blas/cuda/NativeOpExecutioner.cu index 0a6104be4..98f11a275 100644 --- a/libnd4j/blas/cuda/NativeOpExecutioner.cu +++ b/libnd4j/blas/cuda/NativeOpExecutioner.cu @@ -774,78 +774,7 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc, if (xType != zType || !DataTypeUtils::isR(xType)) throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType); - switch (opNum) { - case transform::SoftMax: - case transform::SoftMaxDerivative: - case transform::LogSoftMax: { - if (shape::isVector(hXShapeInfo)) { - int length = shape::length(hXShapeInfo); - int block = nd4j::math::nd4j_min(length, 256); - - launchDims.x = 1; - launchDims.y = block; - launchDims.z += (block * sizeof(double) * 4); - - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, lc->getAllocationPointer(), lc->getReductionPointer(), nullptr, nullptr), FLOAT_TYPES); - } else { - auto shape = shape::shapeOf(hXShapeInfo); - auto reductionPointer = lc->getReductionPointer(); - auto allocationPointer = lc->getAllocationPointer(); - auto specialPointer = reinterpret_cast(allocationPointer); - - // special pointer for special buffer for special ops - auto dimension = reinterpret_cast(specialPointer); - auto maxDimension = dimension + 1; - auto maxShapeBuffer = reinterpret_cast(maxDimension + 1); - auto special = reinterpret_cast (maxShapeBuffer + (MAX_RANK * 2 + 4)); - - - Nd4jLong maxShape[2] = {shape::shapeOf(hXShapeInfo)[0], 1}; - auto hostMaxShapeBuffer = shape::shapeBuffer(2, xType, maxShape); - - prepareShapeBuffer<<<1, 1, 128, *stream>>>(dimension, maxDimension, maxShapeBuffer, shape[0], xType); - - DEBUG_KERNEL(stream, opNum); - - // max 3 - execReduceSame(lc, reduce::Max, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, maxDimension, 1, tadShapeInfo, tadOffsets); - - DEBUG_KERNEL(stream, opNum); - - // sub 1 - execBroadcast(lc, broadcast::Subtract, hX, hXShapeInfo, dX, dXShapeInfo, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, nullptr, hZShapeInfo, 
dZ, dZShapeInfo, dimension, 1, tadShapeInfo, tadOffsets, nullptr, nullptr); - - DEBUG_KERNEL(stream, opNum); - - // exp 3 - execTransformFloat(lc, transform::Exp, hZ, hZShapeInfo, dZ, dZShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); - - DEBUG_KERNEL(stream, opNum); - - //sum 1 - execReduceSame(lc, reduce::Sum, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, maxDimension, 1, tadShapeInfo, tadOffsets); - - // divide 3 - execBroadcast(lc, broadcast::Divide, hZ, hZShapeInfo, dZ, dZShapeInfo, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, nullptr, hZShapeInfo, dZ, dZShapeInfo, dimension, 1, tadShapeInfo, tadOffsets, nullptr, nullptr); - - DEBUG_KERNEL(stream, opNum); - - // log 3 - if (opNum == transform::LogSoftMax) - execTransformFloat(lc, transform::Log, nullptr, hZShapeInfo, dZ, dZShapeInfo, nullptr, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); - else if (opNum == transform::SoftMaxDerivative) - execTransformStrict(lc, transform::SpecialDerivative, nullptr, hZShapeInfo, dZ, dZShapeInfo, nullptr, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); - - nd4j::DebugHelper::checkErrorCode(stream, "SoftMax(...) failed"); - - delete hostMaxShapeBuffer; - } - } - break; - default: { - BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES); - } - } + BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES); } //////////////////////////////////////////////////////////////////////// @@ -869,22 +798,8 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc, if (!DataTypeUtils::isR(zType)) throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType); - if (opNum == transform::Histogram) { - dim3 launchDims(256, 256, 32768); - - Nd4jPointer maskedallocationPointer; - auto length = shape::length(hZShapeInfo); - cudaMalloc(reinterpret_cast(&maskedallocationPointer), length * launchDims.x * DataTypeUtils::sizeOf(nd4j::DataType::INT64)); - auto imaskedallocationPointer = reinterpret_cast(maskedallocationPointer); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, imaskedallocationPointer, reductionPointer, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - - checkCudaErrors(cudaStreamSynchronize(*stream)); - cudaFree(maskedallocationPointer); - } else { - dim3 launchDims(512, 512, 16384); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); - } + dim3 launchDims(512, 512, 16384); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); } @@ -1197,12 +1112,15 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc, 
dim3 launchDims = dim3(512, 512, 32768); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + auto rng = reinterpret_cast(stateHost); + // functions::random::RandomFunction::executeCudaSingle(launchDims, extraPointers, opNum, stateHost, dZ, dZShapeInfo, extraArguments), BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::executeCudaSingle(launchDims, stream, opNum, stateDevice, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); - checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream)); checkCudaErrors(cudaStreamSynchronize(*stream)); cudaFree(stateDevice); + + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// @@ -1224,14 +1142,17 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc, checkCudaErrors(cudaStreamSynchronize(*stream)); checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); + auto rng = reinterpret_cast(stateHost); + dim3 launchDims = dim3(512, 512, 32768); auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo); // functions::random::RandomFunction::executeCudaDouble(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments); BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaDouble(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); - checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream)); checkCudaErrors(cudaStreamSynchronize(*stream)); cudaFree(stateDevice); + + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// @@ -1254,14 +1175,17 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext *lc, checkCudaErrors(cudaStreamSynchronize(*stream)); checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); + auto rng = reinterpret_cast(stateHost); + dim3 launchDims = dim3(512, 512, 32768); auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo); // functions::random::RandomFunction::executeCudaTriple(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments); BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaTriple(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); - checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream)); checkCudaErrors(cudaStreamSynchronize(*stream)); cudaFree(stateDevice); + + rng->rewindH(shape::length(hZShapeInfo)); } //////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index e404863d5..2c292556c 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -1958,7 +1958,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, void *extraArguments) { LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); } //////////////////////////////////////////////////////////////////////// @@ -1970,7 +1970,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, int opNum, 
Nd4jPointer st void *extraArguments) { LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); } //////////////////////////////////////////////////////////////////////// @@ -1984,7 +1984,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer st void *extraArguments) { LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); + NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); } diff --git a/libnd4j/include/loops/cuda/random.cu b/libnd4j/include/loops/cuda/random.cu index a00445d2c..4cc1c6565 100644 --- a/libnd4j/include/loops/cuda/random.cu +++ b/libnd4j/include/loops/cuda/random.cu @@ -152,9 +152,9 @@ namespace functions { __syncthreads(); // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { + for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) cB[e] = dB[e]; - } + __syncthreads(); @@ -213,9 +213,9 @@ namespace functions { __syncthreads(); // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { + for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) cB[e] = dB[e]; - } + __syncthreads(); @@ -262,9 +262,9 @@ namespace functions { __syncthreads(); // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { + for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) cB[e] = dB[e]; - } + __syncthreads(); int tid = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 4a960e244..760b30bbd 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -107,9 +107,6 @@ #define TRANSFORM_STRICT_OPS \ - (0, SoftMax), \ - (1, SoftMaxDerivative), \ - (2, LogSoftMax) ,\ (3, ELUDerivative), \ (4, TanhDerivative), \ (5, HardTanhDerivative), \ @@ -167,9 +164,7 @@ // these ops return one of FLOAT data types #define TRANSFORM_FLOAT_OPS \ - (0, Histogram), \ (1, Sqrt), \ - (2, Pooling2D) ,\ (3, RSqrt) diff --git a/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp b/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp new file mode 100644 index 000000000..3581dcd9a --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/transforms/histogram.cpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_histogram) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(histogram, 1, 1, false, 0, 1) { + auto input = INPUT_VARIABLE(0); + auto numBins = INT_ARG(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length") + + helpers::histogramHelper(block.launchContext(), *input, *output); + + return Status::OK(); + } + + DECLARE_SHAPE_FN(histogram) { + auto numBins = INT_ARG(0); + + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(numBins, nd4j::DataType::INT64)); + } + + + DECLARE_TYPES(histogram) { + getOpDescriptor() + ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INTS}); + }; + } +} + +#endif diff --git a/libnd4j/include/ops/declarable/headers/transforms.h b/libnd4j/include/ops/declarable/headers/transforms.h index 6451b4f4d..b24fad482 100644 --- a/libnd4j/include/ops/declarable/headers/transforms.h +++ b/libnd4j/include/ops/declarable/headers/transforms.h @@ -206,6 +206,13 @@ namespace nd4j { #if NOT_EXCLUDED(OP_hashcode) DECLARE_REDUCTION_OP(hashcode, 1, 1, false, 0, 0); #endif + + /** + * This operation calculates number of entries per bin + */ + #if NOT_EXCLUDED(OP_histogram) + DECLARE_CUSTOM_OP(histogram, 1, 1, false, 0, 1); + #endif } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp new file mode 100644 index 000000000..49626168c --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp @@ -0,0 +1,83 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + static void histogram_(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) { + auto dx = reinterpret_cast(xBuffer); + auto result = reinterpret_cast(zBuffer); + + int length = shape::length(xShapeInfo); + // FIXME: 2??? 
+ int _threads = 2; + + int span = (length / _threads) + 8; + + X binSize = (max_val - min_val) / (numBins); + + PRAGMA_OMP_PARALLEL_THREADS(_threads) + { + int tid, start, end; + + int *bins = new int[numBins]; + std::memset(bins, 0, sizeof(int) * numBins); + tid = omp_get_thread_num(); + start = span * tid; + end = span * (tid + 1); + if (end > length) end = length; + + PRAGMA_OMP_SIMD + for (int x = start; x < end; x++) { + int idx = (int) ((dx[x] - min_val) / binSize); + if (idx < 0) + idx = 0; + else if (idx >= numBins) + idx = numBins - 1; + + bins[idx]++; + } + + PRAGMA_OMP_CRITICAL + { + PRAGMA_OMP_SIMD + for (int x = 0; x < numBins; x++) { + result[x] += bins[x]; + } + + } + + delete[] bins; + } + } + + void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) { + Nd4jLong numBins = output.lengthOf(); + double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); + double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); + + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu new file mode 100644 index 000000000..e04b1b57a --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -0,0 +1,134 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, double min_val, double max_val) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + auto dx = reinterpret_cast(xBuffer); + auto result = reinterpret_cast(zBuffer); + + __shared__ Z *bins; + __shared__ int length; + __shared__ Z *reductor; + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + bins = (Z *) shmem; + reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x); + + length = shape::length(xShapeInfo); + } + __syncthreads(); + + Z binSize = (max_val - min_val) / (numBins); + + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + bins[e] = (Z) 0.0f; + } + __syncthreads(); + + for (int e = tid; e < length; e+= blockDim.x * gridDim.x) { + int idx = (int) ((dx[e] - min_val) / binSize); + if (idx < 0) idx = 0; + else if (idx >= numBins) idx = numBins - 1; + + nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f); + } + __syncthreads(); + + // transfer shared memory to reduction memory + + + if (gridDim.x > 1) { + unsigned int *tc = (unsigned int *)reductionPointer; + __shared__ bool amLast; + + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + reductor[e] = bins[e]; + } + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { + unsigned int ticket = atomicInc(&tc[16384], gridDim.x); + amLast = (ticket == gridDim.x - 1); + } + __syncthreads(); + + if (amLast) { + tc[16384] = 0; + + // nullify shared memory for future accumulation + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + bins[e] = (Z) 0.0f; + } + + // accumulate reduced bins + for (int r = 0; r < gridDim.x; r++) { + Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins); + + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + bins[e] += ptrBuf[e]; + } + } + __syncthreads(); + + // write them out to Z + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + result[e] = bins[e]; + } + } + } else { + // if there's only 1 block - just write away data + for (int e = threadIdx.x; e < numBins; e += blockDim.x) { + result[e] = bins[e]; + } + } + } + + template + static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) { + int numThreads = 256; + int numBlocks = nd4j::math::nd4j_max(256, nd4j::math::nd4j_min(1, shape::length(xShapeInfo) / numThreads)); + int workspaceSize = numBlocks * numBins; + auto tmp = NDArrayFactory::create('c',{workspaceSize}); + + histogramKernel<<getCudaStream()>>>(xBuffer, xShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, min_val, max_val); + + cudaStreamSynchronize(*context->getCudaStream()); + } + + void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) { + Nd4jLong numBins = output.lengthOf(); + double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); + double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); + + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), 
numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES); + + NDArray::registerSpecialUse({&output}, {&input}); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/histogram.h b/libnd4j/include/ops/declarable/helpers/histogram.h new file mode 100644 index 000000000..b6556599f --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/histogram.h @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_HISTOGRAM_H +#define LIBND4J_HISTOGRAM_H + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output); + } + } +} + +#endif //DEV_TESTS_HISTOGRAM_H diff --git a/libnd4j/include/ops/special_ops.h b/libnd4j/include/ops/special_ops.h index 141da26cc..33cce53c6 100644 --- a/libnd4j/include/ops/special_ops.h +++ b/libnd4j/include/ops/special_ops.h @@ -747,88 +747,7 @@ static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outSha int *allocationPointer, Z *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { - int numBins = (int) extraParams[0]; - Z min_val = extraParams[1]; - Z max_val = extraParams[2]; - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ Z *bins; - __shared__ int length; - __shared__ Z *reductor; - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - bins = (Z *) shmem; - reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x); - - length = shape::length(xShapeBuffer); - } - __syncthreads(); - - Z binSize = (max_val - min_val) / (numBins); - - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - bins[e] = (Z) 0.0f; - } - __syncthreads(); - - for (int e = tid; e < length; e+= blockDim.x * gridDim.x) { - int idx = (int) ((dx[e] - min_val) / binSize); - if (idx < 0) idx = 0; - else if (idx >= numBins) idx = numBins - 1; - - nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f); - } - __syncthreads(); - - // transfer shared memory to reduction memory - - - if (gridDim.x > 1) { - unsigned int *tc = (unsigned int *)reductionPointer; - __shared__ bool amLast; - - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - reductor[e] = bins[e]; - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int ticket = atomicInc(&tc[16384], gridDim.x); - amLast = (ticket == gridDim.x - 1); - } - __syncthreads(); - - if (amLast) { - tc[16384] = 0; - - // nullify shared memory for future accumulation - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - bins[e] = (Z) 0.0f; - } - - // accumulate reduced bins - for (int r = 0; r < gridDim.x; r++) { - Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins); - - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - 
bins[e] += ptrBuf[e]; - } - } - __syncthreads(); - - // write them out to Z - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - result[e] = bins[e]; - } - } - } else { - // if there's only 1 block - just write away data - for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - result[e] = bins[e]; - } - } }; #endif @@ -840,52 +759,7 @@ static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outSha Nd4jLong *zShapeBuffer, Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { - int length = shape::length(xShapeBuffer); - int _threads = 2; - int numBins = (int) extraParams[0]; - int span = (length / _threads) + 8; - - // get min over input - T min_val = extraParams[1]; - T max_val = extraParams[2]; - - T binSize = (max_val - min_val) / (numBins); - - - PRAGMA_OMP_PARALLEL_THREADS(_threads) - { - int tid, start, end; - - int *bins = new int[numBins]; - std::memset(bins, 0, sizeof(int) * numBins); - tid = omp_get_thread_num(); - start = span * tid; - end = span * (tid + 1); - if (end > length) end = length; - - PRAGMA_OMP_SIMD - for (int x = start; x < end; x++) { - int idx = (int) ((dx[x] - min_val) / binSize); - if (idx < 0) - idx = 0; - else if (idx >= numBins) - idx = numBins - 1; - - bins[idx]++; - } - - PRAGMA_OMP_CRITICAL - { - PRAGMA_OMP_SIMD - for (int x = 0; x < numBins; x++) { - result[x] += bins[x]; - } - - } - - delete[] bins; - } } diff --git a/libnd4j/include/ops/special_random_ops.h b/libnd4j/include/ops/special_random_ops.h index 7f6b01aa0..0d90c212a 100644 --- a/libnd4j/include/ops/special_random_ops.h +++ b/libnd4j/include/ops/special_random_ops.h @@ -88,10 +88,10 @@ namespace randomOps { __syncthreads(); // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { + for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) cB[e] = dB[e]; - } -// __syncthreads(); // Eliminated due RTX20xx specific + + __syncthreads(); int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -137,10 +137,6 @@ namespace randomOps { // __syncthreads(); // Eliminated due RTX20xx specific } } - -// __syncthreads(); // Eliminated due RTX20xx specific - if (threadIdx.x == 0 && blockIdx.x == 0) - devRng->rewindH(zLength); } #endif @@ -208,9 +204,6 @@ namespace randomOps { } } } - - // update rng state - rng->rewindH(zLength); } }; @@ -277,7 +270,7 @@ namespace randomOps { for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) cB[e] = dB[e]; -// __syncthreads(); // Eliminated due RTX20xx specific + __syncthreads(); int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -299,10 +292,6 @@ namespace randomOps { z[epm * zEWS] = (nd4j::math::nd4j_sqrt(t * nd4j::math::nd4j_log(r0)) * nd4j::math::nd4j_sin(two_pi * r1)) * stddev + realMean1; } } - -// __syncthreads(); // Eliminated due RTX20xx specific - if (threadIdx.x == 0 && blockIdx.x == 0) - devRng->rewindH(zLength); } #endif @@ -352,9 +341,6 @@ namespace randomOps { z[epm * zEWS] = z1; } } - - // update rng state - rng->rewindH(zLength); } }; @@ -401,10 +387,10 @@ namespace randomOps { __syncthreads(); // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { + for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) cB[e] = dB[e]; - } -// __syncthreads(); // Eliminated due RTX20xx specific + + __syncthreads(); int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -423,10 +409,6 @@ namespace randomOps { // if 
trials is set to 0, effectively we just have successful memset z[e * zEWS] = static_cast(success); } - -// __syncthreads(); // Eliminated due RTX20xx specific - if (trials > 0 && threadIdx.x == 0 && blockIdx.x == 0) - devRng->rewindH(zLength * trials); } #endif @@ -472,10 +454,6 @@ namespace randomOps { z[e * zEWS] = static_cast(success); } } - - // update rng state - if (trials > 0) - rng->rewindH(zLength * trials); } }; @@ -522,10 +500,10 @@ namespace randomOps { __syncthreads(); // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { + for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) cB[e] = dB[e]; - } -// __syncthreads(); // Eliminated due RTX20xx specific + + __syncthreads(); int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -591,10 +569,6 @@ namespace randomOps { z[e * zEWS] = static_cast(success); } } - - // update rng state - if (trials > 0) - rng->rewindH(zLength * trials); } }; @@ -676,10 +650,10 @@ namespace randomOps { __syncthreads(); // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { + for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) cB[e] = dB[e]; - } -// __syncthreads(); // Eliminated due RTX20xx specific + + __syncthreads(); int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -724,9 +698,6 @@ namespace randomOps { z[e] = mean + nd4j::DataTypeUtils::min(); } } - - // update rng state - rng->rewindH(zLength); } }; @@ -788,10 +759,10 @@ namespace randomOps { __syncthreads(); // using this loop instead of memcpy - for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { + for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) cB[e] = dB[e]; - } -// __syncthreads(); // Eliminated due RTX20xx specific + + __syncthreads(); int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -868,9 +839,6 @@ namespace randomOps { } } } - - // update rng state - rng->rewindH(zLength); } }; diff --git a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp index 40ecb6214..2c6de814a 100644 --- a/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/FullBenchmarkSuit.cpp @@ -1724,9 +1724,9 @@ namespace nd4j { } }; - TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); + //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); - output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax"); + //output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax"); return output; } diff --git a/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp b/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp index ae9db9b6c..caad37867 100644 --- a/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp +++ b/libnd4j/include/performance/benchmarking/impl/LightBenchmarkSuit.cpp @@ -60,11 +60,11 @@ namespace nd4j { sbRelu.setY(NDArrayFactory::create_(0.0)); TransformBenchmark tbSigmoid(transform::StrictOps::Sigmoid, "sigmoid"); - TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); + //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); output += helper.runOperationSuit(&sbRelu, generator, batch, "RELU"); output += helper.runOperationSuit(&tbSigmoid, 
generator, batch, "Sigmoid"); - output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax"); + //output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax"); return output; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java index 3e1772bd9..a33a26f67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java @@ -49,7 +49,6 @@ import org.nd4j.linalg.api.ops.impl.summarystats.Variance; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; import org.nd4j.linalg.api.ops.impl.transforms.custom.*; -import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram; import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; import org.nd4j.linalg.api.ops.impl.transforms.gradient.*; @@ -748,8 +747,6 @@ public class LegacyOpMapper { public static Class transformFloatingOpClass(int opNum){ switch (opNum){ - case 0: - return Histogram.class; case 1: return Sqrt.class; case 3: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 71d7dfba3..33d03e490 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -70,10 +70,10 @@ import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; import org.nd4j.linalg.api.ops.impl.transforms.Assert; +import org.nd4j.linalg.api.ops.impl.transforms.Histogram; import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot; import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; import org.nd4j.linalg.api.ops.impl.transforms.custom.*; -import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; import org.nd4j.linalg.api.ops.impl.transforms.gradient.*; @@ -888,7 +888,6 @@ public class OpValidation { SELUDerivative.class, SigmoidDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class, - org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative.class, SoftSignDerivative.class, TanhDerivative.class, SwishDerivative.class, @@ -987,7 +986,6 @@ public class OpValidation { OneHot.class, BinaryMinimalRelativeError.class, BinaryMinimalRelativeError.class, - Histogram.class, InvertPermutation.class, //Uses integer indices ConfusionMatrix.class, //Integer indices Linspace.class, //No input array @@ -1056,7 +1054,9 @@ public class OpValidation { LogicalAnd.class, LogicalNot.class, LogicalOr.class, - LogicalXor.class + LogicalXor.class, + + Histogram.class ); return new HashSet<>(list); diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 5f45a4a8a..702660d74 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -424,7 +424,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum.class, org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast.class, - org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram.class, org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt.class, org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class, @@ -552,7 +551,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.Sin.class, org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh.class, - org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus.class, org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign.class, org.nd4j.linalg.api.ops.impl.transforms.strict.Stabilize.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java index 145baf6e0..29876606d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/Activation.java @@ -19,20 +19,6 @@ package org.nd4j.linalg.activations; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.activations.impl.*; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; -import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet; -import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; -import org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity; -import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; -import org.nd4j.linalg.api.ops.impl.scalar.Step; -import org.nd4j.linalg.api.ops.impl.transforms.strict.*; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.*; -import org.nd4j.linalg.api.ops.impl.transforms.same.Cube; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative; /** * This enum is the factory for the activation function. 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Histogram.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Histogram.java new file mode 100644 index 000000000..31937ded7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Histogram.java @@ -0,0 +1,55 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Histogram op wrapper + * + * @author raver119@gmail.com + */ +public class Histogram extends DynamicCustomOp { + private long numBins; + + public Histogram() { + + } + + public Histogram(INDArray input, INDArray output) { + Preconditions.checkArgument(output.isZ(), "Histogram op output should have integer data type"); + + numBins = output.length(); + inputArguments.add(input); + outputArguments.add(output); + iArguments.add(numBins); + } + + public Histogram(INDArray input, long numBins) { + this.numBins = numBins; + inputArguments.add(input); + iArguments.add(numBins); + } + + @Override + public String opName() { + return "histogram"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Histogram.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Histogram.java deleted file mode 100644 index 84cabd6ef..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/Histogram.java +++ /dev/null @@ -1,114 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.floating; - -import lombok.val; -import onnx.OnnxProto3; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.imports.descriptors.properties.PropertyMapping; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.factory.Nd4j; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.NodeDef; - -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -/** - * @author raver119@gmail.com - */ -public class Histogram extends BaseTransformFloatOp { - public Histogram(SameDiff sameDiff, SDVariable i_v, boolean inPlace, int numBins) { - super(sameDiff, i_v, inPlace); - this.numBins = numBins; - } - - public Histogram(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs, int numBins) { - super(sameDiff, i_v, shape, inPlace, extraArgs); - this.numBins = numBins; - } - - public Histogram(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, int numBins) { - super(sameDiff, i_v, extraArgs); - this.numBins = numBins; - } - - private int numBins = 0; - - public Histogram() { - //no-op - } - - public Histogram(INDArray x, INDArray z) { - setX(x); - setZ(z); - - //FIXME: int cast - numBins = (int) z.length(); - - double max = x.maxNumber().doubleValue(); - double min = x.minNumber().doubleValue(); - - this.extraArgs = new Object[] {(double) numBins, min, max}; - } - - public Histogram(INDArray x, int numberOfBins) { - this(x, Nd4j.create(x.dataType(), numberOfBins)); - } - - @Override - public Map propertiesForFunction() { - Map ret = new LinkedHashMap<>(); - ret.put("numBins",numBins); - return ret; - } - - @Override - public int opNum() { - return 0; - } - - @Override - public String opName() { - return "histogram"; - } - - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - - @Override - public List doDiff(List f1) { - throw new UnsupportedOperationException("Not supported"); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftMaxDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftMaxDerivative.java deleted file mode 100644 index b9bf8d2fe..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftMaxDerivative.java +++ /dev/null @@ -1,76 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.strict; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.CustomOp; -import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; - -import java.nio.Buffer; - -/** - * Softmax derivative - * - * @author Adam Gibson - */ -public class SoftMaxDerivative extends SoftMax { - public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { - super(sameDiff, new SDVariable[]{i_v1, i_v2}); - } - - public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { - super(sameDiff, new SDVariable[]{ i_v1, i_v2}, inPlace); - } - - public SoftMaxDerivative(INDArray x, INDArray z) { - super(x, z); - } - - public SoftMaxDerivative(INDArray x) { - super(x, x); - } - - public SoftMaxDerivative() {} - - - - @Override - public int opNum() { - return 1; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public String opName() { - return "_softmaxderivative"; - } -} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 2c5b7afa8..c85202e35 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -1529,85 +1529,40 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda Nd4j.getExecutioner().commit(); - AllocationPoint old = allocationPoint; - allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false); + AllocationPoint old = allocationPoint; + allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false); - Nd4j.getDeallocatorService().pickObject(this); - trackingPoint = allocationPoint.getObjectId(); + Nd4j.getDeallocatorService().pickObject(this); + trackingPoint = allocationPoint.getObjectId(); + val oldLength = this.length; + this.length = length; - switch(dataType()){ - case DOUBLE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asDoublePointer(); - indexer = DoubleIndexer.create((DoublePointer) pointer); - break; - case FLOAT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asFloatPointer(); - 
indexer = FloatIndexer.create((FloatPointer) pointer); - break; - case BFLOAT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = Bfloat16Indexer.create((ShortPointer) pointer); - break; - case HALF: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = ShortIndexer.create((ShortPointer) pointer); - break; - case LONG: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer(); - indexer = LongIndexer.create((LongPointer) pointer); - break; - case UINT64: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer(); - indexer = LongIndexer.create((LongPointer) pointer); - break; - case INT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer(); - indexer = IntIndexer.create((IntPointer) pointer); - break; - case UINT32: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer(); - indexer = IntIndexer.create((IntPointer) pointer); - break; - case SHORT: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = ShortIndexer.create((ShortPointer) pointer); - break; - case UINT16: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); - indexer = UShortIndexer.create((ShortPointer) pointer); - break; - case BYTE: - this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); - indexer = ByteIndexer.create((BytePointer) pointer); - break; - default: - throw new UnsupportedOperationException(); - } + // if original buffer had host pointer allocated, we'll reallocate host buffer as well + if (old.getHostPointer() != null) { + lazyAllocateHostPointer(); + } - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream()); + val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream()); - MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; - val perfD = PerformanceTracker.getInstance().helperStartTransaction(); + MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; + val perfD = PerformanceTracker.getInstance().helperStartTransaction(); - if (old.isActualOnDeviceSide()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), this.length * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, context.getSpecialStream()); - } else if (old.isActualOnHostSide()) { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), this.length * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); - direction = MemcpyDirection.HOST_TO_DEVICE; - } + if (old.isActualOnDeviceSide()) { + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), oldLength * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, 
context.getSpecialStream()); + } else if (old.isActualOnHostSide()) { + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), oldLength * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); + direction = MemcpyDirection.HOST_TO_DEVICE; + } - context.getSpecialStream().synchronize(); + context.getSpecialStream().synchronize(); - PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction); + PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction); - allocationPoint.tickDeviceWrite(); - // we're keeping pointer reference for JVM - pointer.address(); + allocationPoint.tickDeviceWrite(); - - // we need to update length with new value now - //this.length = length; + // we need to update length with new value now + //this.length = length; if(isAttached()){ // do nothing here, that's workspaces } else{ @@ -1619,7 +1574,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override public long capacity() { - return pointer.capacity(); + if (allocationPoint.getHostPointer() != null) + return pointer.capacity(); + else + return length; } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index c5c14d034..dce16b44d 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -16754,6 +16754,27 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); private native void allocate(); } // #endif + + /** + * This operation calculates number of entries per bin + */ +// #if NOT_EXCLUDED(OP_histogram) + @Namespace("nd4j::ops") public static class histogram extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public histogram(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. 
*/ + public histogram(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public histogram position(long position) { + return (histogram)super.position(position); + } + + public histogram() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 339e3e6b4..ab2685be2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -79,7 +79,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy; import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse; import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; @@ -931,13 +930,6 @@ public class Nd4jTestsC extends BaseNd4jTest { System.out.println("Second: OK"); } - @Test - public void testSoftmaxDerivative() { - INDArray input = Nd4j.create(new double[] {-1.07, -0.01, 0.45, 0.95, 0.45, 0.16, 0.20, 0.80, 0.89, 0.25}).reshape(1, -1).transpose(); - INDArray output = Nd4j.create(10, 1); - Nd4j.getExecutioner().exec(new SoftMaxDerivative(input, output)); - } - @Test public void testVStackDifferentOrders() { @@ -7245,19 +7237,6 @@ public class Nd4jTestsC extends BaseNd4jTest { } } - - @Test - public void testInvalidTransformsSoftmax(){ - INDArray arr = Nd4j.zeros(2,3,4); - try{ - Transforms.softmax(arr); - fail("Expected exception"); - } catch (IllegalArgumentException e){ - //OK - assertTrue(e.getMessage().contains("rank 2")); - } - } - @Test public void testEmptyCasting(){ for(val from : DataType.values()) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java index 1fb9e1154..4121cf1ea 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/buffer/IntDataBufferTests.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.buffer; +import lombok.val; import org.junit.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -28,6 +29,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import java.io.*; +import java.util.Arrays; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -85,10 +87,12 @@ public class IntDataBufferTests extends BaseNd4jTest { public void testReallocation() { DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); assertEquals(4, buffer.capacity()); - int[] old = buffer.asInt(); buffer.reallocate(6); + val old = buffer.asInt(); assertEquals(6, buffer.capacity()); - assertArrayEquals(old, buffer.asInt()); + val newContent = buffer.asInt(); + assertEquals(6, newContent.length); + assertArrayEquals(old, Arrays.copyOf(newContent, old.length)); } @Test @@ -98,12 +102,14 @@ 
public class IntDataBufferTests extends BaseNd4jTest { MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); - int[] old = buffer.asInt(); + val old = buffer.asInt(); assertTrue(buffer.isAttached()); assertEquals(4, buffer.capacity()); buffer.reallocate(6); assertEquals(6, buffer.capacity()); - assertArrayEquals(old, buffer.asInt()); + val newContent = buffer.asInt(); + assertEquals(6, newContent.length); + assertArrayEquals(old, Arrays.copyOf(newContent, old.length)); workspace.close(); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java index 82f504d95..b068a1a65 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java @@ -82,12 +82,12 @@ public class RngTests extends BaseNd4jTest { INDArray narr = Nd4j.randn('c', rows, cols); assertArrayEquals(new long[] {rows, cols}, narr.shape()); assertEquals('c', narr.ordering()); - assertEquals(narr.meanNumber().doubleValue(), 0.0, 0.05); + assertEquals(0.0, narr.meanNumber().doubleValue(), 0.05); INDArray narr2 = Nd4j.randn('f', rows, cols); assertArrayEquals(new long[] {rows, cols}, narr2.shape()); assertEquals('f', narr2.ordering()); - assertEquals(narr2.meanNumber().doubleValue(), 0.0, 0.05); + assertEquals(0.0, narr2.meanNumber().doubleValue(), 0.05); INDArray narr3 = Nd4j.randn('c', new int[] {rows, cols, dim2}); assertArrayEquals(new long[] {rows, cols, dim2}, narr3.shape()); @@ -97,7 +97,7 @@ public class RngTests extends BaseNd4jTest { INDArray narr4 = Nd4j.randn('f', new int[] {rows, cols, dim2}); assertArrayEquals(new long[] {rows, cols, dim2}, narr4.shape()); assertEquals('f', narr4.ordering()); - assertEquals(narr4.meanNumber().doubleValue(), 0.0, 0.05); + assertEquals(0.0, narr4.meanNumber().doubleValue(), 0.05); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java index fe02070fd..c3a66200b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java @@ -29,7 +29,6 @@ import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.BooleanIndexing; @@ -159,7 +158,6 @@ public class CrashTest extends BaseNd4jTest { // logisoftmax, softmax & softmax derivative Nd4j.getExecutioner().exec((CustomOp) new SoftMax(x)); - Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(x)); Nd4j.getExecutioner().exec((CustomOp) new LogSoftMax(x)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java index bd4bedf66..324efed0e 100644 --- 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/indexing/BooleanIndexingTest.java @@ -441,8 +441,8 @@ public class BooleanIndexingTest extends BaseNd4jTest { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(true); INDArray arr = Nd4j.linspace(1,4,4, Nd4j.dataType()).reshape(2,2); - INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{arr},Arrays.asList(2.0), Collections.emptyList(),new GreaterThan()); - assertEquals(4,filtered.length()); + INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{arr}, Arrays.asList(2.0), Collections.emptyList(),new GreaterThan()); + assertEquals(2, filtered.length()); } @@ -450,7 +450,7 @@ public class BooleanIndexingTest extends BaseNd4jTest { public void testChooseGreaterThanZero() { INDArray zero = Nd4j.linspace(0,4,4, Nd4j.dataType()); INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{zero},Arrays.asList(0.0), Collections.emptyList(),new GreaterThan()); - assertEquals(3,filtered.length()); + assertEquals(3, filtered.length()); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java index 3dc773981..1e86d4611 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/mixed/MixedDataTypesTests.java @@ -331,7 +331,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { assertArrayEquals(exp, arr); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = RuntimeException.class) public void testTypesValidation_3() { val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new long[]{4}, DataType.INT); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java index c89ee8dfd..ff9582378 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java @@ -195,47 +195,6 @@ public class DerivativeTests extends BaseNd4jTest { } - @Test - public void testSoftMaxDerivative() { - Random r = new Random(12345L); - - int[] mb = new int[] {10, 2, 1}; - for (int minibatch : mb) { - System.out.println("Minibatch size: " + minibatch); - INDArray z = Nd4j.zeros(minibatch, 5); - double[][] in = new double[minibatch][5]; - double[][] softmax = new double[minibatch][5]; - double[][] expOut = new double[minibatch][5]; - for (int i = 0; i < minibatch; i++) { - double rowSumExp = 0.0; - for (int j = 0; j < 5; j++) { - in[i][j] = 10 * r.nextDouble(); - z.putScalar(new int[] {i, j}, in[i][j]); - rowSumExp += FastMath.exp(in[i][j]); - } - for (int j = 0; j < 5; j++) { - softmax[i][j] = FastMath.exp(in[i][j]) / rowSumExp; - expOut[i][j] = softmax[i][j] * (1.0 - softmax[i][j]); - } - } - - INDArray sm = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(z.dup()))[0]; - INDArray zPrime = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(z))[0]; - System.out.println(Arrays.toString(sm.data().asDouble())); - System.out.println(Arrays.toString(zPrime.data().asDouble())); - assertNotEquals(sm, zPrime); - - for (int i = 0; i < minibatch; 
i++) { - for (int j = 0; j < 5; j++) { - double relError = Math.abs(expOut[i][j] - zPrime.getDouble(i, j)) - / (Math.abs(expOut[i][j]) + Math.abs(zPrime.getDouble(i, j))); - // System.out.println("Error: " + relError); - assertTrue(relError < REL_ERROR_TOLERANCE); - } - } - } - } - @Test public void testSoftPlusDerivative() { @@ -384,180 +343,6 @@ public class DerivativeTests extends BaseNd4jTest { } } - @Test - public void softmaxsimpleLossTest() { - /* - Softmax derivative is correct if it is standalone - But when we are applying it in the chain rule the current derivative function is incomplete. - For this test, I am assuming that the function off interest is just MSE - What the fix is: - We need the derivative of a softmax needs to return a rank 2 matrix. - Right now we get only the diagonal elements of this matrix - http://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network - */ - //random array represeting preout - INDArray X = Nd4j.rand(1, 2); - //preout transformed to y_hat with softmax - INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0]; - - //hard coding something to construct a function with, using MSE - INDArray Y = Nd4j.create(new double[][] {{0.123, 1 - 0.123}}); - - //This is the MSE now - double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue(); - - INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0]; - - //the way we apply the chain rule now is 2*(y-yhat)*softmaxder - INDArray dLdY = Y.sub(YHat).mul(-2); - INDArray currentGradient = dLdY.mul(softmaxDer); - - //what I think we should be doing - // we have x0, x1 -> y0,y1 - //we need the derivatives of the output of the softmax wrt every input (x0,x1) - // we only have dy0/dx0 and dy1/dx1 - // we also need dy0/dx1 and dy1/dx0 - // the below is the chain rule in calc applied when L is a function of y0,y1; y0 and y1 are in turn functions of BOTH (x0 and x1) - // dL/dx0 = (dl/dy0) * (dy0/dx0) + (dL/dy1) * (dy1/dx0) - // dL/dx1 = (dl/dy0) * (dy0/dx1) + (dL/dy1) * (dy1/dx1) - // worked it out on paper and googled it (should have googled first, gave formula from link above) - // dy0/dx0 = y0*(1-y0) = y0*y1 - // dy1/dx0 = -y1*(1-y1) = -y0*y1 - // dy0/dx1 = -y0*(1-y0) = -y0*y1 - // dy1/dx1 = y1*(1-y1) = y0*y1 - //[ dL/dy0 dL/dy1] [[dy0/dx0 dy1/dx0] [dy0/dx1 dy1/dx1]] - double y0y1 = softmaxDer.getDouble(0, 0); - //hack but this is what we need to implement, straightforward here but complicated for >2 - //INDArray mysoftmaxDer = Nd4j.create(new double[][] {{y0y1,y0y1*-1},{-1*y0y1,y0y1}}); - INDArray mysoftmaxDer = correctSoftmax(X); - INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1); - - double epsilon = 0.0001; - INDArray Xiplus, Ximinus; - INDArray YHatplus, YHatminus; - double lossplus, lossminus; - - INDArray numGradient = Nd4j.zeros(1, 2); - - for (int i = 0; i < 2; i++) { - /* change X one value one at a time */ - - // +epsilon - double x = X.getDouble(0, i); - Xiplus = X.dup(); - Xiplus.put(0, i, x + epsilon); - YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0]; - lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue(); - - // -epsilon - Ximinus = X.dup(); - Ximinus.put(0, i, x - epsilon); - YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0]; - lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue(); - - double gradienti = (lossplus - lossminus) / (2 * epsilon); - numGradient.put(0, i, gradienti); - } - 
System.out.println("========================="); - System.out.println("NUMERICAL:"); - System.out.println(numGradient); - System.out.println("\nCURRENTLY:"); - System.out.println(currentGradient); - System.out.println("\nMY GRADIENT:"); - System.out.println(myGradient + "\n"); - System.out.println( - "Because of the nature of the derivative of the softmax for length = 2, our current method will make it off by a factor of 2"); - System.out.println("========================="); - } - - - @Test - public void softmaxsimplelongerlengthLossTest() { - /* - Read comments in earlier test for length = 2 - */ - //random array represeting preout - int someLength = 7; - - INDArray X = Nd4j.rand(1, someLength); - //preout transformed to y_hat with softmax - INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0]; - - //hard coding something to construct a function with, using MSE - INDArray temp = Nd4j.rand(1, someLength); - INDArray Y = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(temp))[0]; - - //This is the MSE now - double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue(); - - INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0]; - - //the way we apply the chain rule now is 2*(y-yhat)*softmaxder - INDArray dLdY = Y.sub(YHat).mul(-2); - INDArray currentGradient = dLdY.mul(softmaxDer); - - INDArray mysoftmaxDer = correctSoftmax(X); - INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1); - - double epsilon = 0.0001; - INDArray Xiplus, Ximinus; - INDArray YHatplus, YHatminus; - double lossplus, lossminus; - - INDArray numGradient = Nd4j.zeros(1, someLength); - - for (int i = 0; i < someLength; i++) { - /* change X one value one at a time */ - - // +epsilon - double x = X.getDouble(0, i); - Xiplus = X.dup(); - Xiplus.put(0, i, x + epsilon); - YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0]; - lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue(); - - // -epsilon - Ximinus = X.dup(); - Ximinus.put(0, i, x - epsilon); - YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0]; - lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue(); - - double gradienti = (lossplus - lossminus) / (2 * epsilon); - numGradient.put(0, i, gradienti); - } - System.out.println("========================="); - System.out.println("NUMERICAL GRADIENT:"); - System.out.println(new NDArrayStrings(6).format(numGradient).toString()); - System.out.println("\nANALYTIC USING EXISTING SOFTMAX DER:"); - System.out.println(new NDArrayStrings(6).format(currentGradient).toString()); - System.out.println("\nGRADIENT USING MY VERSION OF SOFTMAX DER:"); - System.out.println(new NDArrayStrings(6).format(myGradient).toString()); - System.out.println("========================="); - } - - - public static INDArray correctSoftmax(INDArray X) { - // this is only for X a row vector - // should return rank 2 matrix diagonal elements are pi*(1-pi) - //rest are -pi*pj - INDArray p = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0]; - INDArray pCol = p.dup().transpose(); - INDArray pipj = pCol.mmul(p); - pipj.muli(-1); - - //so now pipj is correct except for the diagonal elements - // which by the way is what our current softmax der gives us - INDArray diagp = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0]; - - - //ugly for loop to correct diag elements - for (int i = 0; i < X.length(); i++) { - pipj.put(i, i, diagp.getDouble(0, i)); - 
} - - return pipj; - } - @Override public char ordering() { return 'f'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index 614044a26..dfcf5dc79 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -50,9 +50,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; +import org.nd4j.linalg.api.ops.impl.transforms.Histogram; +import org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth; import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; -import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp; @@ -876,14 +877,14 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test public void testHistogram1() { INDArray x = Nd4j.linspace(1, 1000, 100000, DataType.DOUBLE); - INDArray z = Nd4j.zeros(DataType.DOUBLE,new long[]{20}); + INDArray z = Nd4j.zeros(DataType.LONG,new long[]{20}); INDArray xDup = x.dup(); INDArray zDup = z.dup(); - INDArray zExp = Nd4j.create(DataType.DOUBLE, 20).assign(5000); + INDArray zExp = Nd4j.create(DataType.LONG, 20).assign(5000); - Histogram histogram = new Histogram(x, z); + val histogram = new Histogram(x, z); Nd4j.getExecutioner().exec(histogram); @@ -899,16 +900,13 @@ public class OpExecutionerTestsC extends BaseNd4jTest { public void testHistogram2() { INDArray x = Nd4j.create(new float[] {0f, 0f, 0f, 5f, 5f, 5f, 10f, 10f, 10f}); - INDArray xDup = x.dup(); - INDArray zExp = Nd4j.zeros(DataType.FLOAT, 10).putScalar(0, 3f).putScalar(5, 3f).putScalar(9, 3f); + INDArray zExp = Nd4j.zeros(DataType.LONG, 10).putScalar(0, 3).putScalar(5, 3).putScalar(9, 3); - Histogram histogram = new Histogram(x, 10); + val histogram = new Histogram(x, 10); - Nd4j.getExecutioner().exec(histogram); - - INDArray z = histogram.z(); + val z = Nd4j.getExecutioner().exec(histogram)[0]; assertEquals(xDup, x); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java index 68d704661..ac19b87f2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/JsonSerdeTests.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.serde; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; +import lombok.val; import org.junit.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; @@ -58,7 +59,7 @@ public class JsonSerdeTests extends BaseNd4jTest { Nd4j.getRandom().setSeed(12345); INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4).muli(20).subi(10); - ObjectMapper om = new ObjectMapper(); + val om = new ObjectMapper(); for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, 
DataType.INT, DataType.SHORT, DataType.BYTE, DataType.UBYTE, DataType.BOOL, DataType.UTF8}) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index 5f8f42759..ff26d7b1b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -673,7 +673,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { workspace.initializeWorkspace(); long reqMemory = 12 * Nd4j.sizeOfDataType(arrayCold.dataType()); - assertEquals(reqMemory + reqMemory % 8 + Nd4j.sizeOfDataType(DOUBLE), workspace.getCurrentSize()); + assertEquals(reqMemory + reqMemory % 8, workspace.getCurrentSize()); log.info("-----------------------"); @@ -692,7 +692,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { array.addi(1.0); - assertEquals(reqMem + reqMem % 8 + Nd4j.sizeOfDataType(DOUBLE), workspace.getPrimaryOffset()); + assertEquals(reqMem + reqMem % 8, workspace.getPrimaryOffset()); assertEquals("Failed on iteration " + x, 10, array.sumNumber().doubleValue(), 0.01); @@ -746,7 +746,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { INDArray dup = array.dup(); - assertEquals((reqMemory + reqMemory % 8) * 2 + Nd4j.sizeOfDataType(DOUBLE), workspace.getPrimaryOffset()); + assertEquals((reqMemory + reqMemory % 8) * 2, workspace.getPrimaryOffset()); assertEquals(5, dup.sumNumber().doubleValue(), 0.01); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java index 263a8c5da..0d46cb4ec 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java @@ -91,7 +91,7 @@ public class DebugModeTests extends BaseNd4jTest { assertEquals(0, ws.getDeviceOffset()); // array buffer should be spilled now - assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE) + Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize()); + assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize()); } } @@ -118,7 +118,7 @@ public class DebugModeTests extends BaseNd4jTest { assertEquals(0, ws.getDeviceOffset()); // array buffer should be spilled now - assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE) + Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize()); + assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize()); } try (val ws = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "R_119_1992")) {
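
For reference, a minimal usage sketch of the behaviour exercised by the updated tests in this patch: the new Histogram custom op (org.nd4j.linalg.api.ops.impl.transforms.Histogram) is executed like any other CustomOp and produces integral (LONG) bin counts, and DataBuffer.reallocate() now grows the buffer while preserving the existing content. The class name, main() wrapper, bin count of 20 and printed values below are illustrative only; the constructors, exec() usage and expected results follow OpExecutionerTestsC and IntDataBufferTests as changed above, not a definitive API reference.

import java.util.Arrays;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.Histogram;
import org.nd4j.linalg.factory.Nd4j;

public class HistogramAndReallocSketch {
    public static void main(String[] args) {
        // Histogram as a CustomOp: 100k values spread evenly over [1, 1000], 20 bins
        INDArray x = Nd4j.linspace(1, 1000, 100000, DataType.DOUBLE);
        Histogram histogram = new Histogram(x, 20);

        // exec() on a CustomOp returns its output arrays; the bin counts are the first one
        INDArray counts = Nd4j.getExecutioner().exec(histogram)[0];
        System.out.println(counts.dataType()); // LONG - counts are integral, not the input's floating point type
        System.out.println(counts);            // roughly 5000 entries per bin for this uniform input

        // Buffer reallocation: capacity grows, the original content is preserved at the start
        DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
        buffer.reallocate(6);
        int[] grown = buffer.asInt();            // reflects the new capacity of 6
        System.out.println(buffer.capacity());   // 6
        System.out.println(Arrays.toString(Arrays.copyOf(grown, 4))); // [1, 2, 3, 4]
    }
}

The same pattern replaces the removed Histogram(x, z) + histogram.z() flow from the old floating-point Histogram op; an explicit output array can still be passed (as in testHistogram1), but it should now be created with DataType.LONG.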