Fix datatype issue with GpuGraphRunnerTest (#198)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
f25056363b
commit
5039fb22b7
|
@ -18,6 +18,7 @@
|
||||||
package org.nd4j.tensorflow.conversion;
|
package org.nd4j.tensorflow.conversion;
|
||||||
|
|
||||||
import org.nd4j.BaseND4JTest;
|
import org.nd4j.BaseND4JTest;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
@ -40,6 +41,11 @@ import static org.junit.Assert.assertNotNull;
|
||||||
|
|
||||||
public class GpuGraphRunnerTest extends BaseND4JTest {
|
public class GpuGraphRunnerTest extends BaseND4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 180000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGraphRunner() throws Exception {
|
public void testGraphRunner() throws Exception {
|
||||||
byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
|
byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
|
||||||
|
@ -68,8 +74,8 @@ public class GpuGraphRunnerTest extends BaseND4JTest {
|
||||||
assertEquals(2,graphRunner.getInputOrder().size());
|
assertEquals(2,graphRunner.getInputOrder().size());
|
||||||
assertEquals(1,graphRunner.getOutputOrder().size());
|
assertEquals(1,graphRunner.getOutputOrder().size());
|
||||||
|
|
||||||
INDArray input1 = Nd4j.linspace(1,4,4).reshape(4);
|
INDArray input1 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT);
|
||||||
INDArray input2 = Nd4j.linspace(1,4,4).reshape(4);
|
INDArray input2 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT);
|
||||||
|
|
||||||
Map<String,INDArray> inputs = new LinkedHashMap<>();
|
Map<String,INDArray> inputs = new LinkedHashMap<>();
|
||||||
inputs.put("input_0",input1);
|
inputs.put("input_0",input1);
|
||||||
|
|
Loading…
Reference in New Issue