From a5f5ac72b10c2a3bcc312a28cceed989c60a4f3e Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 9 Dec 2019 20:08:59 +0300 Subject: [PATCH] reduce bool changes (#118) * reduce bool changes Signed-off-by: raver119 * reduce bool tweaks Signed-off-by: raver119 --- .../include/loops/cuda/reduce/reduce_bool.cu | 8 ++++- .../loops/cuda/reduce/reduce_float.chpp | 2 +- .../nd4j/linalg/api/ops/BaseReduceBoolOp.java | 3 ++ .../linalg/api/ops/impl/reduce/bool/All.java | 9 ++++++ .../linalg/api/ops/impl/reduce/bool/Any.java | 5 ++++ .../api/ops/impl/reduce/bool/IsInf.java | 4 +++ .../api/ops/impl/reduce/bool/IsNaN.java | 4 +++ .../ops/executioner/CudaExecutioner.java | 18 +++++++++-- .../nativecpu/ops/NativeOpExecutioner.java | 11 +++++++ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 30 +++++++++++++++++++ 10 files changed, 89 insertions(+), 5 deletions(-) diff --git a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu index a785094f1..52ca3decc 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu @@ -237,6 +237,8 @@ template template __host__ void ReduceBoolFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + nd4j_printf("Step A%i\n", -1); + if(shape::isEmpty(hXShapeInfo)) { if(shape::isEmpty(hZShapeInfo)) @@ -251,7 +253,8 @@ __host__ void ReduceBoolFunction::intermediateXD(dim3 launchDims, cudaStrea auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer(); // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); + functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hZShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); + nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolDim empty(...) failed"); } else { simpleReduce<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); @@ -274,6 +277,9 @@ __host__ void ReduceBoolFunction::intermediateScalar(dim3 launchDims, cudaS auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); if (res != 0) throw nd4j::cuda_exception::build("ReduceBoolFunction::intermediateScalar: failed to copy resulting scalar", res); + + nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolScalar empty(...) failed"); + } else { simpleScalar<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); diff --git a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp index ef366caf7..110cc0f68 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp +++ b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp @@ -249,7 +249,7 @@ __host__ void ReduceFloatFunction::intermediateXD(dim3 launchDims, cudaStre auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer(); // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShape, hXShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr); + functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShape, hZShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr); } else { simpleReduce<<>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java index 6e2801c67..dd2072758 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java @@ -102,4 +102,7 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoo "with 2 inputs, second input (axis) must be an integer datatype for %s, got %s", getClass(), dataTypes); return Collections.singletonList(DataType.BOOL); } + + + public abstract boolean emptyValue(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java index 60b835135..a465728d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java @@ -41,6 +41,10 @@ public class All extends BaseReduceBoolOp { super(x); } + public All(INDArray x, int... axis) { + super(x, axis); + } + @Override public int opNum() { return 1; @@ -65,4 +69,9 @@ public class All extends BaseReduceBoolOp { public String tensorflowName() { return "All"; } + + @Override + public boolean emptyValue() { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java index 1cd31d19d..7daebd4cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java @@ -65,4 +65,9 @@ public class Any extends BaseReduceBoolOp { public String tensorflowName() { return "Any"; } + + @Override + public boolean emptyValue() { + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java index c74acf734..cb93a832e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java @@ -71,4 +71,8 @@ public class IsInf extends BaseReduceBoolOp { return Collections.singletonList(f().zerosLike(arg())); } + @Override + public boolean emptyValue() { + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java index 611219d3e..c8cd72f2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java @@ -71,4 +71,8 @@ public class IsNaN extends BaseReduceBoolOp { return Collections.singletonList(f().zerosLike(arg())); } + @Override + public boolean emptyValue() { + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 904e1305e..16568fbf4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -935,6 +935,18 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } + // FIXME: this should be moved down to C++ on per-op basis + // reduce to scalar case, ReduceBool ops require special treatment + if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { + if (op.z() == null) { + op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue())); + } else { + op.z().assign(((BaseReduceBoolOp) op).emptyValue()); + } + + return context; + } + long st = profilingConfigurableHookIn(op); checkForCompression(op); @@ -994,9 +1006,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } - if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) { - return null; - } + //if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) { + // return null; + //} val dataType = op.resultType(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index b6af2e5f2..d12efba59 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -265,7 +265,18 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } } + // FIXME: this should be moved down to C++ on per-op basis val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector()); + // reduce to scalar case, ReduceBool ops require special treatment + if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { + if (op.z() == null) { + op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue())); + } else { + op.z().assign(((BaseReduceBoolOp) op).emptyValue()); + } + + return op.z(); + } //validateDataType(Nd4j.dataType(), op); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index bac06b981..68551e53d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8134,6 +8134,36 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6), hStack); } + + @Test + public void testReduceAll_1() { + val x = Nd4j.empty(DataType.FLOAT); + val e = Nd4j.scalar(true); + val z = Nd4j.exec(new All(x)); + + assertEquals(e, z); + } + + @Test + public void testReduceAll_2() { + val x = Nd4j.ones(DataType.FLOAT, 0); + val e = Nd4j.scalar(true); + val z = Nd4j.exec(new All(x)); + + assertEquals(e, z); + } + + @Test + public void testReduceAll_3() { + val x = Nd4j.create(DataType.FLOAT, 0); + assertEquals(1, x.rank()); + + val e = Nd4j.scalar(true); + val z = Nd4j.exec(new All(x, 0)); + + assertEquals(e, z); + } + @Override public char ordering() { return 'c';