parent
3f4379927a
commit
4211f3b4ce
|
@ -1352,12 +1352,7 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
|||
*/
|
||||
@Override
|
||||
public INDArray scalar(float value) {
|
||||
if (Nd4j.dataType() == DataType.FLOAT || Nd4j.dataType() == DataType.HALF)
|
||||
return create(new float[] {value}, new int[0], new int[0], 0);
|
||||
else if (Nd4j.dataType() == DataType.DOUBLE)
|
||||
return scalar((double) value);
|
||||
else
|
||||
return scalar((int) value);
|
||||
return create(new float[] {value}, new long[0], new long[0], DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1368,10 +1363,7 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory {
|
|||
*/
|
||||
@Override
|
||||
public INDArray scalar(double value) {
|
||||
if (Nd4j.dataType() == DataType.DOUBLE)
|
||||
return create(new double[] {value}, new int[0], new int[0], 0);
|
||||
else
|
||||
return scalar((float) value);
|
||||
return create(new double[] {value}, new long[0], new long[0], DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
Loading…
Reference in New Issue