parent
3f4379927a
commit
4211f3b4ce
|
@ -1352,12 +1352,7 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public INDArray scalar(float value) {
|
public INDArray scalar(float value) {
|
||||||
if (Nd4j.dataType() == DataType.FLOAT || Nd4j.dataType() == DataType.HALF)
|
return create(new float[] {value}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||||
return create(new float[] {value}, new int[0], new int[0], 0);
|
|
||||||
else if (Nd4j.dataType() == DataType.DOUBLE)
|
|
||||||
return scalar((double) value);
|
|
||||||
else
|
|
||||||
return scalar((int) value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1368,10 +1363,7 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public INDArray scalar(double value) {
|
public INDArray scalar(double value) {
|
||||||
if (Nd4j.dataType() == DataType.DOUBLE)
|
return create(new double[] {value}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||||
return create(new double[] {value}, new int[0], new int[0], 0);
|
|
||||||
else
|
|
||||||
return scalar((float) value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
Loading…
Reference in New Issue