diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 059412c19..56a713d6c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -1595,7 +1595,7 @@ public class Nd4jTestsC extends BaseNd4jTestWithBackends { @ParameterizedTest @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") public void testSigmoid(Nd4jBackend backend) { - INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}); + INDArray n = Nd4j.create(new float[] {1, 2, 3, 4}).castTo(DataType.DOUBLE); INDArray assertion = Nd4j.create(new float[] {0.73105858f, 0.88079708f, 0.95257413f, 0.98201379f}).castTo(DataType.DOUBLE); INDArray sigmoid = Transforms.sigmoid(n, false); assertEquals(assertion, sigmoid);