From 243bf866c42978ed9f52ca8b47bedb70235272fb Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 23 Aug 2019 09:00:10 +0300 Subject: [PATCH] [WIP] Few fixes (#153) * throw exception if op execution failed Signed-off-by: raver119 * expected for test Signed-off-by: raver119 * one more ismax test Signed-off-by: raver119 * ismax view fix Signed-off-by: raver119 --- .../ops/declarable/helpers/cpu/ismax.cpp | 8 ++-- .../ops/executioner/CudaExecutioner.java | 4 +- .../nativecpu/ops/NativeOpExecutioner.java | 4 +- .../nd4j/linalg/custom/CustomOpsTests.java | 37 +++++++++++++++++++ 4 files changed, 47 insertions(+), 6 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp index 0a4ee2fd7..45ec30b58 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/ismax.cpp @@ -84,8 +84,8 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector if (length < ELEMENT_THRESHOLD) { for (int i = 0; i < length; i++) { - if (currMax < input->e(i*eleStride)) { - currMax = input->e(i*eleStride); + if (currMax < input->e(i)) { + currMax = input->e(i); maxIdx = i; } output->p(i, 0.f); @@ -97,8 +97,8 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector int maxIdxLocal = maxIdx; auto currMaxLocal = currMax; for (int i = 0; i < length; i++) { - if (currMaxLocal < input->e(i*eleStride)) { - currMaxLocal = input->e(i*eleStride); + if (currMaxLocal < input->e(i)) { + currMaxLocal = input->e(i); maxIdxLocal = i; } output->p(i, 0.f); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 7d41ba986..789f0f1a3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -2489,7 +2489,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { val ctx = AtomicAllocator.getInstance().getDeviceContext(); ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation()); - nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer()); + val status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer()); + if (status != 0) + throw new RuntimeException("Op [" + op.opName() + "] execution failed"); for (val arr:op.outputArguments()) AtomicAllocator.getInstance().registerAction(ctx, arr); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 23cedba09..11373c440 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -2077,7 +2077,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } - loop.execCustomOp2(null, op.opHash(), context.contextPointer()); + val status = loop.execCustomOp2(null, op.opHash(), context.contextPointer()); + if (status != 0) + throw new RuntimeException("Op [" + op.opName() + "] execution failed"); if (context.getOutputArrays().isEmpty()) return new INDArray[0]; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 675f60007..e325adea8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; +import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal; @@ -40,6 +41,7 @@ import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.nativeblas.NativeOpsHolder; import java.util.List; @@ -604,4 +606,39 @@ public class CustomOpsTests extends BaseNd4jTest { assertTrue(Shape.shapeEquals(e.shape(), z.shape())); assertEquals(e, z); } + + @Test(expected = RuntimeException.class) + public void testInputValidationMergeMax(){ + INDArray[] inputs = new INDArray[]{ + Nd4j.createFromArray(0.0f, 1.0f, 2.0f).reshape('c', 1, 3), + Nd4j.createFromArray(1.0f).reshape('c', 1, 1)}; + + INDArray out = Nd4j.create(DataType.FLOAT, 1, 3).assign(Double.NaN); + CustomOp op = DynamicCustomOp.builder("mergemax") + .addInputs(inputs) + .addOutputs(out) + .callInplace(false) + .build(); + + Nd4j.exec(op); + System.out.println(out); + } + + + @Test + public void testIsMaxView(){ + INDArray predictions = Nd4j.rand(DataType.FLOAT, 3, 4, 3, 2); + + INDArray row = predictions.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.point(0)); + row = row.reshape(1, row.length()); + assertArrayEquals(new long[]{1, 4}, row.shape()); + + val result1 = row.ulike(); + val result2 = row.ulike(); + + Nd4j.exec(new IsMax(row.dup(), result1, 1)); //OK + Nd4j.exec(new IsMax(row, result2, 1)); //C++ exception + + assertEquals(result1, result2); + } }