Fix datatype issue with GpuGraphRunnerTest (#198)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
f25056363b
commit
5039fb22b7
|
@ -18,6 +18,7 @@
|
|||
package org.nd4j.tensorflow.conversion;
|
||||
|
||||
import org.nd4j.BaseND4JTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.junit.Test;
|
||||
|
@ -40,6 +41,11 @@ import static org.junit.Assert.assertNotNull;
|
|||
|
||||
public class GpuGraphRunnerTest extends BaseND4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 180000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGraphRunner() throws Exception {
|
||||
byte[] content = IOUtils.toByteArray(new ClassPathResource("/tf_graphs/nd4j_convert/simple_graph/frozen_model.pb").getInputStream());
|
||||
|
@ -68,8 +74,8 @@ public class GpuGraphRunnerTest extends BaseND4JTest {
|
|||
assertEquals(2,graphRunner.getInputOrder().size());
|
||||
assertEquals(1,graphRunner.getOutputOrder().size());
|
||||
|
||||
INDArray input1 = Nd4j.linspace(1,4,4).reshape(4);
|
||||
INDArray input2 = Nd4j.linspace(1,4,4).reshape(4);
|
||||
INDArray input1 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT);
|
||||
INDArray input2 = Nd4j.linspace(1,4,4).reshape(4).castTo(DataType.FLOAT);
|
||||
|
||||
Map<String,INDArray> inputs = new LinkedHashMap<>();
|
||||
inputs.put("input_0",input1);
|
||||
|
|
Loading…
Reference in New Issue