scalar constructor fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-21 08:50:59 +03:00
parent 4310e87860
commit 3f4379927a
3 changed files with 15 additions and 25 deletions

View File

@ -1329,13 +1329,19 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
*/
@Override
public INDArray scalar(Number value) {
if (Nd4j.dataType() == DataType.DOUBLE)
return scalar(value.doubleValue(), 0);
if (Nd4j.dataType() == DataType.FLOAT || Nd4j.dataType() == DataType.HALF)
return scalar(value.floatValue(), 0);
if (Nd4j.dataType() == DataType.INT)
return scalar(value.intValue(), 0);
throw new IllegalStateException("Illegal data opType " + Nd4j.dataType());
if (value instanceof Double)
return create(new double[]{value.doubleValue()}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Float)
return create(new float[]{value.floatValue()}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Long)
return create(new long[]{value.longValue()}, new long[0], new long[0], DataType.LONG, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Integer)
return create(new int[]{value.intValue()}, new long[0], new long[0], DataType.INT, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Short)
return create(new short[]{value.shortValue()}, new long[0], new long[0], DataType.SHORT, Nd4j.getMemoryManager().getCurrentWorkspace());
else if (value instanceof Byte)
return create(new byte[]{value.byteValue()}, new long[0], new long[0], DataType.BYTE, Nd4j.getMemoryManager().getCurrentWorkspace());
throw new IllegalStateException("Unknown instance of Number: [" + value.getClass().getCanonicalName() + "]");
}
/**

View File

@ -3599,6 +3599,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
// #include <ShapeDescriptor.h>
// #include <helpers/ConstantShapeHelper.h>
// #include <array/DataBuffer.h>
// #include <execution/AffinityManager.h>
@Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(float arg0, @Const @ByRef NDArray arg1);

View File

@ -3599,6 +3599,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
// #include <ShapeDescriptor.h>
// #include <helpers/ConstantShapeHelper.h>
// #include <array/DataBuffer.h>
// #include <execution/AffinityManager.h>
@Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(float arg0, @Const @ByRef NDArray arg1);
@ -18208,24 +18209,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
}
// #endif
// #if NOT_EXCLUDED(OP_space_to_batch_nd)
@Namespace("nd4j::ops") public static class space_to_batch_nd extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public space_to_batch_nd(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public space_to_batch_nd position(long position) {
return (space_to_batch_nd)super.position(position);
}
public space_to_batch_nd() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
*
*