[WIP] Few fixes (#153)
* throw exception if op execution failed Signed-off-by: raver119 <raver119@gmail.com> * expected for test Signed-off-by: raver119 <raver119@gmail.com> * one more ismax test Signed-off-by: raver119 <raver119@gmail.com> * ismax view fix Signed-off-by: raver119 <raver119@gmail.com>master
parent
80d35377d4
commit
243bf866c4
|
@ -84,8 +84,8 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
if (length < ELEMENT_THRESHOLD) {
|
||||
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (currMax < input->e<X>(i*eleStride)) {
|
||||
currMax = input->e<X>(i*eleStride);
|
||||
if (currMax < input->e<X>(i)) {
|
||||
currMax = input->e<X>(i);
|
||||
maxIdx = i;
|
||||
}
|
||||
output->p<Z>(i, 0.f);
|
||||
|
@ -97,8 +97,8 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
|
|||
int maxIdxLocal = maxIdx;
|
||||
auto currMaxLocal = currMax;
|
||||
for (int i = 0; i < length; i++) {
|
||||
if (currMaxLocal < input->e<X>(i*eleStride)) {
|
||||
currMaxLocal = input->e<X>(i*eleStride);
|
||||
if (currMaxLocal < input->e<X>(i)) {
|
||||
currMaxLocal = input->e<X>(i);
|
||||
maxIdxLocal = i;
|
||||
}
|
||||
output->p<Z>(i, 0.f);
|
||||
|
|
|
@ -2489,7 +2489,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
val ctx = AtomicAllocator.getInstance().getDeviceContext();
|
||||
((CudaOpContext) context).setCudaStream(ctx.getOldStream(), ctx.getBufferReduction(), ctx.getBufferAllocation());
|
||||
|
||||
nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());
|
||||
val status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());
|
||||
if (status != 0)
|
||||
throw new RuntimeException("Op [" + op.opName() + "] execution failed");
|
||||
|
||||
for (val arr:op.outputArguments())
|
||||
AtomicAllocator.getInstance().registerAction(ctx, arr);
|
||||
|
|
|
@ -2077,7 +2077,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
}
|
||||
|
||||
|
||||
loop.execCustomOp2(null, op.opHash(), context.contextPointer());
|
||||
val status = loop.execCustomOp2(null, op.opHash(), context.contextPointer());
|
||||
if (status != 0)
|
||||
throw new RuntimeException("Op [" + op.opName() + "] execution failed");
|
||||
|
||||
if (context.getOutputArrays().isEmpty())
|
||||
return new INDArray[0];
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
|||
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
||||
import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
|
||||
|
@ -40,6 +41,7 @@ import org.nd4j.linalg.api.shape.Shape;
|
|||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -604,4 +606,39 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
assertTrue(Shape.shapeEquals(e.shape(), z.shape()));
|
||||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Test(expected = RuntimeException.class)
|
||||
public void testInputValidationMergeMax(){
|
||||
INDArray[] inputs = new INDArray[]{
|
||||
Nd4j.createFromArray(0.0f, 1.0f, 2.0f).reshape('c', 1, 3),
|
||||
Nd4j.createFromArray(1.0f).reshape('c', 1, 1)};
|
||||
|
||||
INDArray out = Nd4j.create(DataType.FLOAT, 1, 3).assign(Double.NaN);
|
||||
CustomOp op = DynamicCustomOp.builder("mergemax")
|
||||
.addInputs(inputs)
|
||||
.addOutputs(out)
|
||||
.callInplace(false)
|
||||
.build();
|
||||
|
||||
Nd4j.exec(op);
|
||||
System.out.println(out);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testIsMaxView(){
|
||||
INDArray predictions = Nd4j.rand(DataType.FLOAT, 3, 4, 3, 2);
|
||||
|
||||
INDArray row = predictions.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(0), NDArrayIndex.point(0));
|
||||
row = row.reshape(1, row.length());
|
||||
assertArrayEquals(new long[]{1, 4}, row.shape());
|
||||
|
||||
val result1 = row.ulike();
|
||||
val result2 = row.ulike();
|
||||
|
||||
Nd4j.exec(new IsMax(row.dup(), result1, 1)); //OK
|
||||
Nd4j.exec(new IsMax(row, result2, 1)); //C++ exception
|
||||
|
||||
assertEquals(result1, result2);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue