parent
4310e87860
commit
3f4379927a
|
@ -1329,13 +1329,19 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
|||
*/
|
||||
@Override
|
||||
public INDArray scalar(Number value) {
|
||||
if (Nd4j.dataType() == DataType.DOUBLE)
|
||||
return scalar(value.doubleValue(), 0);
|
||||
if (Nd4j.dataType() == DataType.FLOAT || Nd4j.dataType() == DataType.HALF)
|
||||
return scalar(value.floatValue(), 0);
|
||||
if (Nd4j.dataType() == DataType.INT)
|
||||
return scalar(value.intValue(), 0);
|
||||
throw new IllegalStateException("Illegal data opType " + Nd4j.dataType());
|
||||
if (value instanceof Double)
|
||||
return create(new double[]{value.doubleValue()}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Float)
|
||||
return create(new float[]{value.floatValue()}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Long)
|
||||
return create(new long[]{value.longValue()}, new long[0], new long[0], DataType.LONG, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Integer)
|
||||
return create(new int[]{value.intValue()}, new long[0], new long[0], DataType.INT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Short)
|
||||
return create(new short[]{value.shortValue()}, new long[0], new long[0], DataType.SHORT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
else if (value instanceof Byte)
|
||||
return create(new byte[]{value.byteValue()}, new long[0], new long[0], DataType.BYTE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
throw new IllegalStateException("Unknown instance of Number: [" + value.getClass().getCanonicalName() + "]");
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -3599,6 +3599,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
|||
// #include <ShapeDescriptor.h>
|
||||
// #include <helpers/ConstantShapeHelper.h>
|
||||
// #include <array/DataBuffer.h>
|
||||
// #include <execution/AffinityManager.h>
|
||||
|
||||
|
||||
@Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(float arg0, @Const @ByRef NDArray arg1);
|
||||
|
|
|
@ -3599,6 +3599,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
|||
// #include <ShapeDescriptor.h>
|
||||
// #include <helpers/ConstantShapeHelper.h>
|
||||
// #include <array/DataBuffer.h>
|
||||
// #include <execution/AffinityManager.h>
|
||||
|
||||
|
||||
@Namespace("nd4j") public static native @ByVal @Name("operator -") NDArray subtract(float arg0, @Const @ByRef NDArray arg1);
|
||||
|
@ -18208,24 +18209,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
}
|
||||
// #endif
|
||||
|
||||
// #if NOT_EXCLUDED(OP_space_to_batch_nd)
|
||||
@Namespace("nd4j::ops") public static class space_to_batch_nd extends DeclarableCustomOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public space_to_batch_nd(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public space_to_batch_nd(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public space_to_batch_nd position(long position) {
|
||||
return (space_to_batch_nd)super.position(position);
|
||||
}
|
||||
|
||||
public space_to_batch_nd() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
*
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue