From 3f4379927a41df7bf23cf7ea425245af72bfa7f2 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 21 Aug 2019 08:50:59 +0300 Subject: [PATCH] scalar constructor fix Signed-off-by: raver119 --- .../linalg/factory/BaseNDArrayFactory.java | 20 ++++++++++++------- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 1 + .../java/org/nd4j/nativeblas/Nd4jCpu.java | 19 +----------------- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index ab9c31686..53be087f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -1329,13 +1329,19 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { */ @Override public INDArray scalar(Number value) { - if (Nd4j.dataType() == DataType.DOUBLE) - return scalar(value.doubleValue(), 0); - if (Nd4j.dataType() == DataType.FLOAT || Nd4j.dataType() == DataType.HALF) - return scalar(value.floatValue(), 0); - if (Nd4j.dataType() == DataType.INT) - return scalar(value.intValue(), 0); - throw new IllegalStateException("Illegal data opType " + Nd4j.dataType()); + if (value instanceof Double) + return create(new double[]{value.doubleValue()}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); + else if (value instanceof Float) + return create(new float[]{value.floatValue()}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); + else if (value instanceof Long) + return create(new long[]{value.longValue()}, new long[0], new long[0], DataType.LONG, Nd4j.getMemoryManager().getCurrentWorkspace()); + else if (value instanceof Integer) + return create(new int[]{value.intValue()}, new long[0], new long[0], DataType.INT, Nd4j.getMemoryManager().getCurrentWorkspace()); + else if (value instanceof Short) + return create(new short[]{value.shortValue()}, new long[0], new long[0], DataType.SHORT, Nd4j.getMemoryManager().getCurrentWorkspace()); + else if (value instanceof Byte) + return create(new byte[]{value.byteValue()}, new long[0], new long[0], DataType.BYTE, Nd4j.getMemoryManager().getCurrentWorkspace()); + throw new IllegalStateException("Unknown instance of Number: [" + value.getClass().getCanonicalName() + "]"); } /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 886b258ae..33ba27069 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3599,6 +3599,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc // #include // #include // #include +// #include @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(float arg0, @Const @ByRef NDArray arg1); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 2f1189aea..4fa484a23 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3599,6 +3599,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc // #include // #include // #include +// #include @Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(float arg0, @Const @ByRef NDArray arg1); @@ -18208,24 +18209,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif -// #if NOT_EXCLUDED(OP_space_to_batch_nd) - @Namespace("nd4j::ops") public static class space_to_batch_nd extends DeclarableCustomOp { - static { Loader.load(); } - /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ - public space_to_batch_nd(Pointer p) { super(p); } - /** Native array allocator. Access with {@link Pointer#position(long)}. */ - public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); } - private native void allocateArray(long size); - @Override public space_to_batch_nd position(long position) { - return (space_to_batch_nd)super.position(position); - } - - public space_to_batch_nd() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } -// #endif - /** * *