From 5039fb22b7dced128093ed18e01cc185df067e13 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 29 Jan 2020 21:16:56 +1100 Subject: [PATCH] Fix datatype issue with GpuGraphRunnerTest (#198) Signed-off-by: AlexDBlack --- .../nd4j/tensorflow/conversion/GpuGraphRunnerTest.java | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java index 28cd5b7b2..a035592df 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/src/test/gpujava/org/nd4j/tensorflow/conversion/GpuGraphRunnerTest.java @@ -18,6 +18,7 @@ package org.nd4j.tensorflow.conversion; import org.nd4j.BaseND4JTest; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.protobuf.util.JsonFormat; import org.apache.commons.io.IOUtils; import org.junit.Test; @@ -40,6 +41,11 @@ import static org.junit.Assert.assertNotNull; public class GpuGraphRunnerTest extends BaseND4JTest { + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Test public void testGraphRunner() throws Exception { byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream()); @@ -68,8 +74,8 @@ public class GpuGraphRunnerTest extends BaseND4JTest { assertEquals(2,graphRunner.getInputOrder().size()); assertEquals(1,graphRunner.getOutputOrder().size()); - INDArray input1 = Nd4j.linspace(1,4,4).reshape(4); - INDArray input2 = Nd4j.linspace(1,4,4).reshape(4); + INDArray input1 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT); + INDArray input2 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT); Map inputs = new LinkedHashMap<>(); inputs.put("input_0",input1);