[WIP] Few fixes (#153)

* throw exception if op execution failed

Signed-off-by: raver119 <raver119@gmail.com>

* expected for test

Signed-off-by: raver119 <raver119@gmail.com>

* one more ismax test

Signed-off-by: raver119 <raver119@gmail.com>

* ismax view fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-23 09:00:10 +03:00 committed by GitHub
parent 80d35377d4
commit 243bf866c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 6 deletions

View File

@ -84,8 +84,8 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
if (length < ELEMENT_THRESHOLD) { if (length < ELEMENT_THRESHOLD) {
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
if (currMax < input->e<X>(i*eleStride)) { if (currMax < input->e<X>(i)) {
currMax = input->e<X>(i*eleStride); currMax = input->e<X>(i);
maxIdx = i; maxIdx = i;
} }
output->p<Z>(i, 0.f); output->p<Z>(i, 0.f);
@ -97,8 +97,8 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
int maxIdxLocal = maxIdx; int maxIdxLocal = maxIdx;
auto currMaxLocal = currMax; auto currMaxLocal = currMax;
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
if (currMaxLocal < input->e<X>(i*eleStride)) { if (currMaxLocal < input->e<X>(i)) {
currMaxLocal = input->e<X>(i*eleStride); currMaxLocal = input->e<X>(i);
maxIdxLocal = i; maxIdxLocal = i;
} }
output->p<Z>(i, 0.f); output->p<Z>(i, 0.f);

View File

@ -2489,7 +2489,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val ctx = AtomicAllocator.getInstance().getDeviceContext(); val ctx = AtomicAllocator.getInstance().getDeviceContext();
((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation()); ((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation());
nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer()); val status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());
if (status != 0)
throw new RuntimeException("Op [" + op.opName() + "] execution failed");
for (val arr:op.outputArguments()) for (val arr:op.outputArguments())
AtomicAllocator.getInstance().registerAction(ctx, arr); AtomicAllocator.getInstance().registerAction(ctx, arr);

View File

@ -2077,7 +2077,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
} }
loop.execCustomOp2(null, op.opHash(), context.contextPointer()); val status = loop.execCustomOp2(null, op.opHash(), context.contextPointer());
if (status != 0)
throw new RuntimeException("Op [" + op.opName() + "] execution failed");
if (context.getOutputArrays().isEmpty()) if (context.getOutputArrays().isEmpty())
return new INDArray[0]; return new INDArray[0];

View File

@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal; import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
@ -40,6 +41,7 @@ import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.List; import java.util.List;
@ -604,4 +606,39 @@ public class CustomOpsTests extends BaseNd4jTest {
assertTrue(Shape.shapeEquals(e.shape(), z.shape())); assertTrue(Shape.shapeEquals(e.shape(), z.shape()));
assertEquals(e, z); assertEquals(e, z);
} }
@Test(expected = RuntimeException.class)
public void testInputValidationMergeMax(){
INDArray[] inputs = new INDArray[]{
Nd4j.createFromArray(0.0f, 1.0f, 2.0f).reshape('c', 1, 3),
Nd4j.createFromArray(1.0f).reshape('c', 1, 1)};
INDArray out = Nd4j.create(DataType.FLOAT, 1, 3).assign(Double.NaN);
CustomOp op = DynamicCustomOp.builder("mergemax")
.addInputs(inputs)
.addOutputs(out)
.callInplace(false)
.build();
Nd4j.exec(op);
System.out.println(out);
}
@Test
public void testIsMaxView(){
INDArray predictions = Nd4j.rand(DataType.FLOAT, 3, 4, 3, 2);
INDArray row = predictions.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.point(0));
row = row.reshape(1, row.length());
assertArrayEquals(new long[]{1, 4}, row.shape());
val result1 = row.ulike();
val result2 = row.ulike();
Nd4j.exec(new IsMax(row.dup(), result1, 1)); //OK
Nd4j.exec(new IsMax(row, result2, 1)); //C++ exception
assertEquals(result1, result2);
}
} }