From df84bc7255b53c359ab1e2284eab73de4d390a79 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 27 Aug 2019 10:37:10 +0300 Subject: [PATCH] [WIP] More tweaks (#173) * CUDA empty reduction Signed-off-by: raver119 * - listdiff synchronization fix for CUDA - listdiff test Signed-off-by: raver119 * - IndexReduce ops now allow INDEXING_TYPES output - topK op accepts only INDEXING_TYPES as output Signed-off-by: raver119 --- libnd4j/blas/NDArray.hpp | 4 +- libnd4j/blas/cpu/NativeOpExecutioner.cpp | 6 +- libnd4j/blas/cuda/NativeOpExecutioner.cu | 14 +-- libnd4j/include/helpers/Loops.h | 6 +- .../helpers/cpu/loops/IndexReductionLoops.cpp | 35 +++--- libnd4j/include/loops/cpu/indexreduce.cpp | 32 +++--- libnd4j/include/loops/cuda/indexreduce.cu | 105 +++++++++--------- libnd4j/include/loops/indexreduce.h | 16 +-- .../declarable/generic/parity_ops/top_k.cpp | 2 +- .../ops/declarable/helpers/impl/listdiff.cpp | 25 +---- libnd4j/include/ops/ops.h | 12 +- .../ops/executioner/CudaExecutioner.java | 47 +++++++- .../nd4j/linalg/custom/CustomOpsTests.java | 40 +++++++ .../org/nd4j/linalg/shape/EmptyTests.java | 10 ++ 14 files changed, 217 insertions(+), 137 deletions(-) diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index fdbcae49f..72b029c0b 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -3590,8 +3590,8 @@ void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const if (isS()) throw std::runtime_error("NDArray::applyIndexReduce: you can't use this method on String array!"); - if (target->dataType() != nd4j::DataType::INT64) - throw std::runtime_error("NDArray::applyIndexReduce operations return INT64"); + if (target->dataType() != nd4j::DataType::INT64 && target->dataType() != nd4j::DataType::INT32) + throw std::runtime_error("NDArray::applyIndexReduce operations return INT32/INT64"); void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(this->dataType()) : nullptr; diff --git a/libnd4j/blas/cpu/NativeOpExecutioner.cpp b/libnd4j/blas/cpu/NativeOpExecutioner.cpp index e320b4f57..b2ce7846a 100644 --- a/libnd4j/blas/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/blas/cpu/NativeOpExecutioner.cpp @@ -79,9 +79,10 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, int op #endif auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto hz = reinterpret_cast(hZ); - BUILD_SINGLE_SELECTOR(xType, hz[0] = functions::indexreduce::IndexReduce, ::execScalar(opNum,hX,hXShapeInfo,extraParams), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, hz[0] = functions::indexreduce::IndexReduce, ::execScalar(opNum,hX,hXShapeInfo,extraParams), LIBND4J_TYPES, INDEXING_TYPES); } //////////////////////////////////////////////////////////////////////// @@ -111,9 +112,10 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc, #endif auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); Nd4jLong* hz = reinterpret_cast(hZ); - BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, extraParams, hz, hZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES); // BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::exec(opNum, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffsets), LIBND4J_TYPES); } diff --git a/libnd4j/blas/cuda/NativeOpExecutioner.cu b/libnd4j/blas/cuda/NativeOpExecutioner.cu index b3573c7ab..8c4f1d3fa 100644 --- a/libnd4j/blas/cuda/NativeOpExecutioner.cu +++ b/libnd4j/blas/cuda/NativeOpExecutioner.cu @@ -475,12 +475,12 @@ void NativeOpExecutioner::execIndexReduce(nd4j::LaunchContext *lc, auto numBlocks = shape::length(hZShapeInfo); dim3 launchDims(numBlocks, 256, 32768); - if (zType != nd4j::DataType::INT64) - throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT64 type", zType); + if (zType != nd4j::DataType::INT64 && zType != nd4j::DataType::INT32) + throw datatype_exception::build("NativeOpExecutioner::execIndexReduce requires Z operand to have INT32/INT64 type", zType); auto dz = reinterpret_cast(dZ); - BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduce(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, dz, dZShapeInfo, shape::rank(hZShapeInfo), dimension, dimensionLength, 1, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), LIBND4J_TYPES, INDEXING_TYPES); // TODO: remove after the release auto res = cudaStreamSynchronize(*stream); @@ -567,12 +567,12 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, // FIXME: we want Z to be one of integer types //if (!DataTypeUtils::isZ(zType)) // throw nd4j::datatype_exception("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have one of integer types") - if (zType != nd4j::DataType::INT64) - throw nd4j::datatype_exception::build("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have INT64 data type", zType); + if (zType != nd4j::DataType::INT64 && zType != nd4j::DataType::INT32) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execIndexReduceScalar requires Z operand to have INT32/INT64 data type", zType); auto dz = reinterpret_cast(dZ); - BUILD_SINGLE_SELECTOR(xType, functions::indexreduce::IndexReduce, ::executeIndexReduceScalar(launchDims, stream, + BUILD_DOUBLE_SELECTOR(xType, zType, functions::indexreduce::IndexReduce, ::executeIndexReduceScalar(launchDims, stream, opNum, dX, dXShapeInfo, shape::rank(hXShapeInfo), extraParams, @@ -580,7 +580,7 @@ void NativeOpExecutioner::execIndexReduceScalar(nd4j::LaunchContext *lc, nullptr, 0, 1, allocationPointer, reductionPointer, - nullptr, nullptr), LIBND4J_TYPES); + nullptr, nullptr), LIBND4J_TYPES, INDEXING_TYPES); // TODO: remove after the release auto res = cudaStreamSynchronize(*stream); if (res != 0) diff --git a/libnd4j/include/helpers/Loops.h b/libnd4j/include/helpers/Loops.h index bda04414f..d04d3315d 100644 --- a/libnd4j/include/helpers/Loops.h +++ b/libnd4j/include/helpers/Loops.h @@ -80,14 +80,14 @@ namespace nd4j { }; - template + template class ND4J_EXPORT IndexReductionLoops { private: public: - static void wrapIndexReduce(const int opNum, void* x, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* extraParams); + static void wrapIndexReduce(const int opNum, void* x, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* extraParams); template - static void loopIndexReduce(X* x, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams); + static void loopIndexReduce(X* x, Nd4jLong* xShapeInfo, Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams); }; diff --git a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp index 33e230bd5..0a096b65f 100644 --- a/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp +++ b/libnd4j/include/helpers/cpu/loops/IndexReductionLoops.cpp @@ -24,10 +24,10 @@ using namespace simdOps; ////////////////////////////////////////////////////////////////////////////// -template +template template -void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, - Nd4jLong* z, Nd4jLong* zShapeInfo, +void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, + Z* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, X* extraParams) { @@ -62,7 +62,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, indexValue = OpType::update(indexValue, comp, extraParams); } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -80,7 +80,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, indexValue = OpType::update(indexValue, comp, extraParams); } - z[i * zEws] = indexValue.index; + z[i * zEws] = (Z) indexValue.index; } } break; @@ -98,7 +98,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, indexValue = OpType::update(indexValue, comp, extraParams); } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -122,7 +122,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -148,7 +148,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -176,7 +176,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -206,7 +206,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } } - z[i] = indexValue.index; + z[i] = (Z) indexValue.index; } } break; @@ -227,7 +227,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ); - z[zOffset] = indexValue.index; + z[zOffset] = (Z) indexValue.index; } } break; @@ -248,7 +248,7 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, indexValue = OpType::update(indexValue, comp, extraParams); } - z[i * zEws] = indexValue.index; + z[i * zEws] = (Z) indexValue.index; } } break; @@ -272,18 +272,19 @@ void nd4j::IndexReductionLoops::loopIndexReduce(X* x, Nd4jLong* xShapeInfo, } auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ); - z[zOffset] = indexValue.index; + z[zOffset] = (Z) indexValue.index; } } } } -template -void nd4j::IndexReductionLoops::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams) { +template +void nd4j::IndexReductionLoops::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* vz, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams) { auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); auto extraParams = reinterpret_cast(vextraParams); - DISPATCH_BY_OPNUM_T(loopIndexReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), INDEX_REDUCE_OPS); + DISPATCH_BY_OPNUM_TT(loopIndexReduce, PARAMS(x, xShapeInfo, z, zShapeInfo, tadShapeInfo, tadOffsets, extraParams), INDEX_REDUCE_OPS); } -BUILD_SINGLE_TEMPLATE(template void nd4j::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, Nd4jLong* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES); \ No newline at end of file +BUILD_DOUBLE_TEMPLATE(template void nd4j::IndexReductionLoops, ::wrapIndexReduce(const int opNum, void* vx, Nd4jLong* xShapeInfo, void* z, Nd4jLong* zShapeInfo, Nd4jLong* tadShapeInfo, Nd4jLong* tadOffsets, void* vextraParams), LIBND4J_TYPES, INDEXING_TYPES); \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/indexreduce.cpp b/libnd4j/include/loops/cpu/indexreduce.cpp index 951ac287b..5a7beee24 100644 --- a/libnd4j/include/loops/cpu/indexreduce.cpp +++ b/libnd4j/include/loops/cpu/indexreduce.cpp @@ -31,26 +31,27 @@ namespace functions { namespace indexreduce { //////////////////////////////////////////////////////////////////////// -template Nd4jLong IndexReduce::execScalar( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { - RETURNING_DISPATCH_BY_OPNUM_T(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS); +template +Nd4jLong IndexReduce::execScalar( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { + RETURNING_DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(x, xShapeInfo, extraParams), INDEX_REDUCE_OPS); } //////////////////////////////////////////////////////////////////////// -template -void IndexReduce::exec(const int opNum, +template +void IndexReduce::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, - Nd4jLong *z, Nd4jLong *zShapeInfo, + void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { -DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); +DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); } //////////////////////////////////////////////////////////////////////// -template +template template -Nd4jLong IndexReduce::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) { +Nd4jLong IndexReduce::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) { auto x = reinterpret_cast(vx); auto extraParams = reinterpret_cast(vextraParams); @@ -105,15 +106,16 @@ Nd4jLong IndexReduce::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextra //////////////////////////////////////////////////////////////////////// -template +template template -void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, +void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, - Nd4jLong *z, Nd4jLong *zShapeInfo, + void *vz, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); auto extraParams = reinterpret_cast(vextraParams); const Nd4jLong zLen = shape::length(zShapeInfo); @@ -124,12 +126,12 @@ void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, const auto indexValue = OpType::startingIndexValue(x); PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold()) for (uint i = 0; i < zLen; i++) - z[i] = indexValue.index;; + z[i] = (Z) indexValue.index;; return; } if(shape::isScalar(zShapeInfo)) { - z[0] = execScalar(x,xShapeInfo,extraParams); + z[0] = (Z) execScalar(x,xShapeInfo,extraParams); return; } @@ -146,11 +148,11 @@ void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, tadOffsets = tadPack.primaryOffsets(); } - nd4j::IndexReductionLoops::template loopIndexReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); + nd4j::IndexReductionLoops::template loopIndexReduce(x, xShapeInfo, z, zShapeInfo, tadOnlyShapeInfo, tadOffsets, extraParams); } -BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES); +BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES); } } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/indexreduce.cu b/libnd4j/include/loops/cuda/indexreduce.cu index 18e5b1432..5f0cf07ae 100644 --- a/libnd4j/include/loops/cuda/indexreduce.cu +++ b/libnd4j/include/loops/cuda/indexreduce.cu @@ -29,37 +29,37 @@ using namespace simdOps; -template +template static __global__ void simpleIndexReduceGeneric(const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, - Nd4jLong *result, + void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { - functions::indexreduce::IndexReduce::transform(op,dx,xShapeInfo,extraParams,result,resultShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); + functions::indexreduce::IndexReduce::transform(op,dx,xShapeInfo,extraParams,result,resultShapeInfo,dimension,dimensionLength,postProcessOrNot,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets); } namespace functions { namespace indexreduce { - template - _CUDA_H void IndexReduce::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, + template + _CUDA_H void IndexReduce::executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, - Nd4jLong *result, Nd4jLong *resultShapeInfo, + void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { - simpleIndexReduceGeneric<<>>(opNum, + simpleIndexReduceGeneric<<>>(opNum, dx, xShapeInfo, xRank, extraParams, result, resultShapeInfo, 0, @@ -67,13 +67,11 @@ namespace functions { 1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets); - - nd4j::DebugHelper::checkErrorCode(stream, "execIndexReduceScalar(...) failed"); } - template - _CUDA_H void IndexReduce::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { - simpleIndexReduceGeneric<<>>( + template + _CUDA_H void IndexReduce::executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int opNum, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) { + simpleIndexReduceGeneric<<>>( opNum, dx, xShapeInfo, xRank, @@ -83,8 +81,6 @@ namespace functions { dimension, dimensionLength, 1, allocationBuffer, reductionBuffer, tadOnlyShapeInfo, tadOffsets); - - DEBUG_KERNEL(stream, opNum); } // This is the un-specialized struct. Note that we prevent instantiation of this @@ -122,14 +118,14 @@ namespace functions { } }; - template + template template - __device__ void IndexReduce::aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) { + __device__ void IndexReduce::aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements, void *vextraParams) { // start the shared memory loop on the next power of 2 less // than the block size. If block size is not a power of 2, // accumulate the intermediate sums in the remainder range. - auto extraParams = static_cast(vextraParams); - IndexValue *sPartials = *sPartialsRef; + auto extraParams = static_cast(vextraParams); + IndexValue *sPartials = *sPartialsRef; Nd4jLong floorPow2 = blockDim.x; if (floorPow2 & (floorPow2 - 1)) { @@ -138,8 +134,8 @@ namespace functions { } if (tid >= floorPow2) { - IndexValue prev = sPartials[tid - floorPow2]; - IndexValue curr = sPartials[tid]; + IndexValue prev = sPartials[tid - floorPow2]; + IndexValue curr = sPartials[tid]; sPartials[tid - floorPow2] = OpType::update(prev,curr,extraParams); } __syncthreads(); @@ -147,21 +143,21 @@ namespace functions { for (int activeThreads = floorPow2 >> 1;activeThreads; activeThreads >>= 1) { if (tid < activeThreads && tid + activeThreads < numElements) { - IndexValue curr = sPartials[tid]; - IndexValue next = sPartials[tid + activeThreads]; + IndexValue curr = sPartials[tid]; + IndexValue next = sPartials[tid + activeThreads]; sPartials[tid] = OpType::update(curr,next,extraParams); } __syncthreads(); } } - template - __device__ void IndexReduce::transform( + template + __device__ void IndexReduce::transform( const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, - Nd4jLong *result, + void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, @@ -170,15 +166,15 @@ namespace functions { void *reductionBuffer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { - DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, result, resultShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); + DISPATCH_BY_OPNUM_TT(transform, PARAMS(x, xShapeInfo, extraParams, result, resultShapeInfo, dimension, dimensionLength, postProcessOrNot, allocationBuffer, reductionBuffer, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); } - template + template template - __device__ void IndexReduce::transform(void *vdx, Nd4jLong *xShapeInfo, + __device__ void IndexReduce::transform(void *vdx, Nd4jLong *xShapeInfo, void *vextraParams, - Nd4jLong *result, Nd4jLong *resultShapeInfo, + void *vresult, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *vreductionBuffer, @@ -186,18 +182,19 @@ namespace functions { /**int * Gpu information for the problem */ - auto dx = static_cast(vdx); - auto extraParams = static_cast(vextraParams); - auto reductionBuffer = static_cast(vreductionBuffer); + auto dx = reinterpret_cast(vdx); + auto result = reinterpret_cast(vresult); + auto extraParams = static_cast(vextraParams); + auto reductionBuffer = static_cast(vreductionBuffer); auto order = shape::order(xShapeInfo); int tid = blockIdx.x * blockDim.x + threadIdx.x; __shared__ volatile int resultScalar; //shared memory space for storing intermediate results - __shared__ IndexValue* sPartials; + __shared__ IndexValue* sPartials; if(threadIdx.x == 0) { extern __shared__ unsigned char shmem[]; - sPartials = reinterpret_cast*>(shmem); + sPartials = reinterpret_cast*>(shmem); } __syncthreads(); @@ -210,7 +207,7 @@ namespace functions { //only compute the tad indexes once - IndexValue reduction = OpType::startingIndexValue(dx); + IndexValue reduction = OpType::startingIndexValue(dx); if (threadIdx.x == 0) { if (resultShapeInfo != nullptr) @@ -255,7 +252,7 @@ namespace functions { for(int i = threadIdx.x;i < tadLength; i += blockDim.x) { auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); - IndexValue comp {dx[xOffset], i}; + IndexValue comp {dx[xOffset], i}; sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams); } @@ -264,7 +261,7 @@ namespace functions { __syncthreads(); if (threadIdx.x == 0) { - result[r] = sPartials[threadIdx.x].index; + result[r] = (Z) sPartials[threadIdx.x].index; } __syncthreads(); } @@ -276,7 +273,7 @@ namespace functions { sPartials[threadIdx.x] = OpType::startingIndexValue(dx); for (int x = threadIdx.x; x < tadLength; x+= blockDim.x) { - IndexValue comp {dx[tadOffsetForBlock + x * tadEWS], x}; + IndexValue comp {dx[tadOffsetForBlock + x * tadEWS], x}; sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams); } @@ -285,7 +282,7 @@ namespace functions { __syncthreads(); if (threadIdx.x == 0) { - result[i] = sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams); + result[i] = (Z) sPartials[threadIdx.x].index; //postProcess(sPartials[0],tadLength ,extraParams); } __syncthreads(); } @@ -296,14 +293,14 @@ namespace functions { if(xElementWiseStride >= 1 && order == 'c') { for(Nd4jLong i = tid;i < n; i += (blockDim.x * gridDim.x)) { - IndexValue indexVal = {dx[i * xElementWiseStride], i}; + IndexValue indexVal = {dx[i * xElementWiseStride], i}; reduction = OpType::update(reduction, indexVal, extraParams); } } else { for(Nd4jLong i = tid;i < n; i += blockDim.x * gridDim.x) { auto offset = shape::getIndexOffset(i, xShapeInfo, n); - IndexValue indexVal = {dx[offset], i}; + IndexValue indexVal = {dx[offset], i}; reduction = OpType::update(reduction, indexVal, extraParams); } } @@ -320,7 +317,7 @@ namespace functions { unsigned int *tc = (unsigned int *) reductionBuffer; tid = threadIdx.x; if (threadIdx.x == 0) { - auto pBuffer = reinterpret_cast *>(reductionBuffer); + auto pBuffer = reinterpret_cast *>(reductionBuffer); pBuffer[blockIdx.x] = {sPartials[0].value, sPartials[0].index}; } __threadfence(); @@ -335,7 +332,7 @@ namespace functions { if (amLast) { tc[16384] = 0; - IndexValue *pBuffer = (IndexValue *) reductionBuffer; + IndexValue *pBuffer = (IndexValue *) reductionBuffer; sPartials[threadIdx.x] = OpType::startingIndexValue(dx); @@ -348,14 +345,14 @@ namespace functions { __syncthreads(); if (tid == 0) { - result[0] = sPartials[0].index; + result[0] = (Z) sPartials[0].index; } } } else { if (tid == 0) { auto tc = reinterpret_cast(reductionBuffer); tc[16384] = 0; - result[0] = sPartials[0].index; + result[0] = (Z) sPartials[0].index; } } @@ -365,30 +362,30 @@ namespace functions { - template - Nd4jLong IndexReduce::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { + template + Nd4jLong IndexReduce::execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams) { return 0; } - template - void IndexReduce::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { + template + void IndexReduce::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { } - template + template template - Nd4jLong IndexReduce:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) { + Nd4jLong IndexReduce:: execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams) { return 0; } - template + template template - _CUDA_H void IndexReduce::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { + _CUDA_H void IndexReduce::exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES); + BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT IndexReduce, , LIBND4J_TYPES, INDEXING_TYPES); } } diff --git a/libnd4j/include/loops/indexreduce.h b/libnd4j/include/loops/indexreduce.h index 40f98c692..792ed16a9 100755 --- a/libnd4j/include/loops/indexreduce.h +++ b/libnd4j/include/loops/indexreduce.h @@ -52,35 +52,35 @@ namespace functions { namespace indexreduce { - template + template class IndexReduce { public: #ifdef __CUDACC__ - static __device__ void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int *dimension,int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); + static __device__ void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int *dimension,int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); template - static __device__ void aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements,void *extraParams); + static __device__ void aggregatePartials(IndexValue **sPartialsRef, Nd4jLong tid, Nd4jLong numElements,void *extraParams); template - static __device__ void transform(void *dx, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); + static __device__ void transform(void *dx, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); - static _CUDA_H void executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); + static _CUDA_H void executeIndexReduceScalar(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); - static _CUDA_H void executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); + static _CUDA_H void executeIndexReduce(dim3 launchDims, cudaStream_t *stream, const int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *result, Nd4jLong *resultShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot, int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets); #endif static Nd4jLong execScalar(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams); - static void exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); + static void exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); template static _CUDA_H Nd4jLong execScalar(void *x, Nd4jLong *xShapeInfo, void *extraParams); template - static _CUDA_H void exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); + static _CUDA_H void exec(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *result, Nd4jLong *resultShapeInfoBuffer, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset); }; } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp index ea2e3330a..bd16cdd79 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/top_k.cpp @@ -87,7 +87,7 @@ namespace nd4j { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(0, nd4j::DataType::ANY) - ->setAllowedOutputTypes(1, {ALL_INTS}); + ->setAllowedOutputTypes(1, {ALL_INDICES}); } } } diff --git a/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp b/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp index baa08dad9..c840f6960 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/listdiff.cpp @@ -42,10 +42,11 @@ namespace helpers { Nd4jLong listDiffCount(nd4j::LaunchContext * context, NDArray* values, NDArray* keep) { auto xType = values->dataType(); - values->syncToHost(); - keep->syncToHost(); + NDArray::preparePrimaryUse({},{values, keep}); BUILD_SINGLE_SELECTOR(xType, return listDiffCount_, (values, keep), LIBND4J_TYPES); + + NDArray::registerPrimaryUse({},{values, keep}); } BUILD_SINGLE_TEMPLATE(template Nd4jLong listDiffCount_, (NDArray* values, NDArray* keep);, LIBND4J_TYPES); @@ -97,16 +98,7 @@ namespace helpers { int listDiffFunctor(nd4j::LaunchContext * context, NDArray* values, NDArray* keep, NDArray* output1, NDArray* output2) { auto xType = values->dataType(); - values->syncToHost(); - - if (keep != nullptr) - keep->syncToHost(); - - if (output1 != nullptr) - output1->syncToHost(); - - if (output2 != nullptr) - output2->syncToHost(); + NDArray::preparePrimaryUse({output1, output2}, {values, keep}); int result = 0; @@ -118,14 +110,7 @@ namespace helpers { throw std::runtime_error("ListDiff: Only integer and floating point data types are supported"); } - if (keep != nullptr) - keep->syncToDevice(); - - if (output1 != nullptr) - output1->syncToDevice(); - - if (output2 != nullptr) - output2->syncToDevice(); + NDArray::registerPrimaryUse({output1, output2}, {values, keep}); return result; } diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index 38122f985..fe6bfae81 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -3746,7 +3746,7 @@ namespace simdOps { }; - template + template class IndexAbsoluteMax { public: static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { @@ -3799,7 +3799,7 @@ namespace simdOps { } }; - template + template class FirstIndex { public: static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { @@ -3861,7 +3861,7 @@ namespace simdOps { }; - template + template class LastIndex { public: static _CUDA_HD inline functions::indexreduce::IndexValue op(functions::indexreduce::IndexValue val, X *extraParams) { @@ -3920,7 +3920,7 @@ namespace simdOps { }; - template + template class IndexMax { public: @@ -3974,7 +3974,7 @@ namespace simdOps { }; - template + template class IndexAbsoluteMin { public: static _CUDA_HD inline functions::indexreduce::IndexValue op( @@ -4030,7 +4030,7 @@ namespace simdOps { }; - template + template class IndexMin { public: static _CUDA_HD inline functions::indexreduce::IndexValue op( diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 38a1ba382..54649692c 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -226,6 +226,21 @@ public class CudaExecutioner extends DefaultOpExecutioner { */ protected INDArray naiveExec(ReduceOp op, int... dimension) { long st = profilingConfigurableHookIn(op); + + if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ + //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] + //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" + if(op.z() != null){ + Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." + + " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); + op.z().assign(op.x()); + return op.z(); + } else { + op.setZ(op.x().dup()); + return op.z(); + } + } + INDArray ret = op.z(); checkForCompression(op); @@ -482,6 +497,20 @@ public class CudaExecutioner extends DefaultOpExecutioner { public INDArray exec(ReduceOp op) { checkForCompression(op); + if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ + //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] + //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" + if(op.z() != null){ + Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." + + " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); + op.z().assign(op.x()); + return op.z(); + } else { + op.setZ(op.x().dup()); + return op.z(); + } + } + val dimension = op.dimensions().toIntVector(); if (extraz.get() == null) @@ -890,6 +919,22 @@ public class CudaExecutioner extends DefaultOpExecutioner { protected CudaContext invoke(ReduceOp op, int[] dimension) { + CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); + + if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ + //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] + //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" + if(op.z() != null){ + Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." + + " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); + op.z().assign(op.x()); + return context; + } else { + op.setZ(op.x().dup()); + return context; + } + } + long st = profilingConfigurableHookIn(op); checkForCompression(op); @@ -913,8 +958,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]"); - CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y()); - if (CudaEnvironment.getInstance().getConfiguration().isDebug()) lastOp.set(op.opName()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 6c4595a0c..f325348fb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -733,6 +733,46 @@ public class CustomOpsTests extends BaseNd4jTest { } } + @Test + public void testListDiff(){ + INDArray x = Nd4j.createFromArray(0, 1, 2, 3); + INDArray y = Nd4j.createFromArray(3, 1); + + INDArray out = Nd4j.create(DataType.INT, 2); + INDArray outIdx = Nd4j.create(DataType.INT, 2); + + Nd4j.exec(DynamicCustomOp.builder("listdiff") + .addInputs(x, y) + .addOutputs(out, outIdx) + .build()); + + INDArray exp = Nd4j.createFromArray(0, 2); + + assertEquals(exp, out); //Values in x not in y + assertEquals(exp, outIdx); //Indices of the values in x not in y + } + + @Test + public void testTopK1(){ + INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); + INDArray k = Nd4j.scalar(1); + INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); + INDArray outIdx = Nd4j.create(DataType.INT, 1); + + Nd4j.exec(DynamicCustomOp.builder("top_k") + .addInputs(x, k) + .addOutputs(outValue, outIdx) + .addBooleanArguments(false) //not sorted + .addIntegerArguments(1) + .build()); + + INDArray expValue = Nd4j.createFromArray(10.0); + INDArray expIdx = Nd4j.createFromArray(3); + + assertEquals(expValue, outValue); + assertEquals(expIdx, outIdx); + } + @Test public void testMaxPool2Dbp_1() { val x = Nd4j.create(DataType.HALF, 2,3,16,16).assign(Double.NaN); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java index 261e1e300..e7e8f8288 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -299,6 +300,15 @@ public class EmptyTests extends BaseNd4jTest { assertNotNull(result[0].shapeInfoDataBuffer().asLong()); } + @Test + public void testAllEmptyReduce(){ + INDArray x = Nd4j.createFromArray(true, true, true); + val all = new All(x); + all.setEmptyReduce(true); //For TF compatibility - empty array for axis (which means no-op - and NOT all array reduction) + INDArray out = Nd4j.exec(all); + assertEquals(x, out); + } + @Override public char ordering() { return 'c';