one more scalar constructor fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-21 08:53:36 +03:00
parent 3f4379927a
commit 4211f3b4ce
1 changed files with 2 additions and 10 deletions

View File

@ -1352,12 +1352,7 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
*/ */
@Override @Override
public INDArray scalar(float value) { public INDArray scalar(float value) {
if (Nd4j.dataType() == DataType.FLOAT || Nd4j.dataType() == DataType.HALF) return create(new float[] {value}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
return create(new float[] {value}, new int[0], new int[0], 0);
else if (Nd4j.dataType() == DataType.DOUBLE)
return scalar((double) value);
else
return scalar((int) value);
} }
/** /**
@ -1368,10 +1363,7 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
*/ */
@Override @Override
public INDArray scalar(double value) { public INDArray scalar(double value) {
if (Nd4j.dataType() == DataType.DOUBLE) return create(new double[] {value}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
return create(new double[] {value}, new int[0], new int[0], 0);
else
return scalar((float) value);
} }
@Override @Override