Fix datatype issue with GpuGraphRunnerTest (#198)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2020-01-29 21:16:56 +11:00 committed by GitHub
parent f25056363b
commit 5039fb22b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 2 deletions

View File

@ -18,6 +18,7 @@
package org.nd4j.tensorflow.conversion; package org.nd4j.tensorflow.conversion;
import org.nd4j.BaseND4JTest; import org.nd4j.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.shade.protobuf.util.JsonFormat; import org.nd4j.shade.protobuf.util.JsonFormat;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.junit.Test; import org.junit.Test;
@ -40,6 +41,11 @@ import static org.junit.Assert.assertNotNull;
public class GpuGraphRunnerTest extends BaseND4JTest { public class GpuGraphRunnerTest extends BaseND4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L;
}
@Test @Test
public void testGraphRunner() throws Exception { public void testGraphRunner() throws Exception {
byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream()); byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
@ -68,8 +74,8 @@ public class GpuGraphRunnerTest extends BaseND4JTest {
assertEquals(2,graphRunner.getInputOrder().size()); assertEquals(2,graphRunner.getInputOrder().size());
assertEquals(1,graphRunner.getOutputOrder().size()); assertEquals(1,graphRunner.getOutputOrder().size());
INDArray input1 = Nd4j.linspace(1,4,4).reshape(4); INDArray input1 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT);
INDArray input2 = Nd4j.linspace(1,4,4).reshape(4); INDArray input2 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT);
Map<String,INDArray> inputs = new LinkedHashMap<>(); Map<String,INDArray> inputs = new LinkedHashMap<>();
inputs.put("input_0",input1); inputs.put("input_0",input1);