diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 56a713d6c..4d2ffed77 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8591,7 +8591,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { } private static INDArray fwd(INDArray input, INDArray W, INDArray b){ - INDArray ret = Nd4j.createUninitialized(input.size(0), W.size(1)); + INDArray ret = Nd4j.createUninitialized(input.size(0), W.size(1)).castTo(DataType.DOUBLE); input.mmuli(W, ret); ret.addiRowVector(b); return ret;