/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.imports;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.util.HashUtil;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.tensorflow.framework.GraphDef;

import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

import static org.junit.Assert.*;

@Slf4j
@Ignore
@RunWith(Parameterized.class)
public class TensorFlowImportTest extends BaseNd4jTest {
    private static ExecutorConfiguration configuration = ExecutorConfiguration.builder()
            .executionMode(ExecutionMode.SEQUENTIAL)
            .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
            .gatherTimings(true)
            .outputMode(OutputMode.IMPLICIT)
            .build();

    public TensorFlowImportTest(Nd4jBackend backend) {
        super(backend);
    }

    @Override
    public char ordering() {
        return 'c';
    }

    @Before
    public void setUp() {
    }

    @After
    public void tearDown() {
        NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false);
        NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false);
    }

    @Test
    public void testClassHolder() {
        DifferentialFunctionClassHolder.getInstance();
    }

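    // The tests below import frozen TensorFlow graphs (mostly classpath resources under tf_graphs/)
    // through TFGraphMapper, then either execute them with SameDiff or serialize them to FlatBuffers
    // (.fb) files consumed by the libnd4j C++ test suite.
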
    @Test
    public void testSingleExample_1() {
        val g = TFGraphMapper.getInstance().importGraph(new File("C:\\Users\\raver\\Downloads\\mnist.pb"));

        val array = Nd4j.ones(1, 28, 28);
        g.associateArrayWithVariable(array, "flatten_1_input");

        //g.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/mnist.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.VARIABLE_SPACE).build());

        g.execAndEndResult();
    }

    @Test
    public void testAssertImport_1() {
        val graph = TFGraphMapper.getInstance().importGraph(new File("C:\\Users\\raver\\Downloads\\test.pb"));
    }

    @Test
    public void testArgMaxImport_2() throws Exception {
        val graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("/tf_graphs/examples/reductions/argmax3,4,5_-1/frozen_graph.pbtxt").getInputStream());

        graph.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/argmax_macos.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build(), true);

        log.info(graph.asFlatPrint());
    }

    @Test
    public void testArgMaxImport_1() throws Exception {
        val graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("/tf_graphs/argmax.pb.txt").getInputStream());

        log.info(graph.asFlatPrint());
        val result = graph.execAndEndResult();

        val exp = Nd4j.trueVector(new long[]{2, 2, 2});

        assertEquals(exp, result);
    }

    @Test
    public void testIfStatementNodes() throws Exception {
        // /home/agibsonccc/code/dl4j-test-resources/src/main/resources/tf_graphs/examples/simple_cond/frozen_graph.pbtxt
        val resourceInputStream = new ClassPathResource("/tf_graphs/examples/simple_cond/frozen_model.pb").getInputStream();
        val mapper = TFGraphMapper.getInstance();
        val readGraph = TFGraphMapper.getInstance().parseGraphFrom(resourceInputStream);
        val nodes = mapper.nodesByName(readGraph);
        /**
         * Work backwards starting from the condition id (usually a name containing condid/pred_id).
         */

        val firstInput = nodes.get("cond5/Merge");
        val ifNodes = mapper.nodesForIf(firstInput, readGraph);
        assertEquals(5, ifNodes.getFalseNodes().size());
        assertEquals(5, ifNodes.getTrueNodes().size());
        assertEquals(10, ifNodes.getCondNodes().size());


        val secondInput = nodes.get("cond6/Merge");
        val ifNodesTwo = mapper.nodesForIf(secondInput, readGraph);
        assertEquals(5, ifNodesTwo.getFalseNodes().size());
        assertEquals(5, ifNodesTwo.getTrueNodes().size());
        assertEquals(6, ifNodesTwo.getCondNodes().size());


        val parentContext = SameDiff.create();
        val ifStatement = new If();
        ifStatement.initFromTensorFlow(firstInput, parentContext, Collections.emptyMap(), readGraph);
        assertNotNull(ifStatement.getLoopBodyExecution());
        assertNotNull(ifStatement.getFalseBodyExecution());
        assertNotNull(ifStatement.getPredicateExecution());
    }

    @Test
    @Ignore
    public void testIfIgnoreWhileMerge() throws Exception {
        val resourceInputStream = new ClassPathResource("/tf_graphs/examples/simple_while/frozen_model.pb").getInputStream();
        val mapper = TFGraphMapper.getInstance();
        val readGraph = TFGraphMapper.getInstance().parseGraphFrom(resourceInputStream);
        val nodes = mapper.nodesByName(readGraph);
        val firstInput = nodes.get("output/Merge");
        assertNotNull(firstInput);
        assertFalse(mapper.isOpIgnoreException(firstInput));

        val resourceInputStreamIf = new ClassPathResource("/tf_graphs/examples/simple_cond/frozen_model.pb").getInputStream();
        val readGraphIf = TFGraphMapper.getInstance().parseGraphFrom(resourceInputStreamIf);
        val nodesif = mapper.nodesByName(readGraphIf);
        /**
         * Work backwards starting from the condition id (usually a name containing condid/pred_id).
         */

        val secondInput = nodesif.get("cond5/Merge");
        assertNotNull(secondInput);
        assertTrue(mapper.isOpIgnoreException(secondInput));
    }

    @Test
    public void testHashEquality1() {
        long hash = HashUtil.getLongHash("Conv2D");
        assertEquals(-1637140380760460323L, hash);
    }

    @Test
    public void testHashEquality2() {
        long hash = HashUtil.getLongHash("switch");
        assertEquals(-1988317239813741487L, hash);
    }

    @Test
    public void testCustomOps1() {
        val map = Nd4j.getExecutioner().getCustomOperations();

        assertTrue(map.size() > 0);
    }

    @Test
    @Ignore
    public void importGraph1() throws Exception {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream());

        assertNotNull(graph);

        assertEquals(2, graph.variableMap().size());

        SDVariable var0 = graph.variableMap().get("zeros");
        SDVariable var1 = graph.variableMap().get("ones");

        assertNotNull(var0);
        assertNotNull(var1);

        assertNotNull(var0.getArr());
        assertNotNull(var1.getArr());

        assertEquals(0.0, var0.getArr().sumNumber().doubleValue(), 1e-5);
        assertEquals(12.0, var1.getArr().sumNumber().doubleValue(), 1e-5);
    }

    @Test
    @Ignore
    public void importGraph2() throws Exception {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream());

        assertNotNull(graph);
    }

    @Test
    @Ignore
    public void importGraph3() throws Exception {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_log_reg.pb.txt").getInputStream());

        assertNotNull(graph);
    }

    @Test
    @Ignore
    public void testImportIris() throws Exception {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/train_iris.pb").getInputStream());
        assertNotNull(graph);
    }

    @Test
    @Ignore
    public void importGraph4() throws Exception {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_multiply.pb.txt").getInputStream());

        assertNotNull(graph);

        val p0 = Nd4j.create(10, 10).assign(2.0);
        val p1 = Nd4j.create(10, 10).assign(3.0);

        graph.associateArrayWithVariable(p0, graph.variableMap().get("Placeholder"));
        graph.associateArrayWithVariable(p1, graph.variableMap().get("Placeholder_1"));

        graph.var("Placeholder", p0);
        graph.var("Placeholder_1", p1);

        val res = graph.execAndEndResult();

        assertEquals(6.0, res.meanNumber().doubleValue(), 1e-5);
    }

    @Test
    public void testLenet() throws Exception {
        /**
         * Produced with:
         * python ~/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py --input_graph=graph2.pb.txt --input_checkpoint=test3.ckpt --output_graph=graph_frozen2.pb --output_node_name=output/BiasAdd --input_binary=False
         */

        Nd4j.create(1);
        val rawGraph = GraphDef.parseFrom(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream());
        val nodeNames = rawGraph.getNodeList().stream().map(node -> node.getName()).collect(Collectors.toList());
        System.out.println(nodeNames);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream());

        val convNode = tg.getVariable("conv2d/kernel");
        assertNotNull(convNode.getArr());
        val shape = convNode.getShape();
        System.out.println(Arrays.toString(shape));

        // this is NHWC weights. will be changed soon.
        assertArrayEquals(new long[]{5, 5, 1, 32}, shape);
        System.out.println(convNode);
    }

    @Test
    public void testIntermediate2() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/max_lstm.pb").getInputStream());
    }

    @Test
    public void testIntermediate1() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream());

        assertNotNull(tg.getVariable("input"));
        // assertTrue(tg.getVariableSpace().getVariable("input").isPlaceholder());

        val ipod = Nd4j.read(new DataInputStream(new ClassPathResource("tf_graphs/ipod.nd4").getInputStream()));

        tg.setArrayForVariable("input", ipod);

        val buffer = tg.asFlatBuffers(true);
        assertNotNull(buffer);
    }

    @Test
    public void testIntermediateLoop1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/simple_while.pb.txt").getInputStream());

        assertNotNull(tg);

        val graph = FlatGraph.getRootAsFlatGraph(tg.asFlatBuffers(true));

        assertEquals(6, graph.variablesLength());
        // assertEquals("alpha/Assign", graph.nodes(0).name());
    }

    @Test
    @Ignore
    public void testWeirdConvImport() {
        val tg = TFGraphMapper.getInstance().importGraph(new File("/home/agibsonccc/code/raver_tfimport_test1/profiling_conv.pb.txt"));
        assertNotNull(tg);
    }

    @Test
    public void testIntermediateLoop3() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/nested_while.pb.txt").getInputStream());

        assertNotNull(tg);

        // now converting to FlatBuffer
        val fb = tg.asFlatBuffers(true);
        assertNotNull(fb);

        val graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(15, graph.variablesLength());

        //assertEquals("phi/Assign", graph.nodes(0).name());
        //assertEquals("alpha/Assign", graph.nodes(1).name());

        assertEquals(2, graph.nodes(0).inputPairedLength());
        assertEquals(2, graph.nodes(1).inputPairedLength());

        // tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/nested_while.fb"));
    }

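    // The strided-slice test below inspects the serialized FlatNode: extraInteger slots 0..4 hold the
    // begin/ellipsis/end/new_axis/shrink_axis masks, all expected to be 0 for this particular graph.
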
    @Test
    @Ignore
    public void testIntermediateStridedSlice1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_slice.pb.txt").getInputStream());

        assertNotNull(tg);

        val constIn = tg.getVariable("StridedSlice/input");
        assertNotNull(constIn);

        val arr = tg.getArrForVarName(constIn.getVarName());
        assertEquals(139.5, arr.sumNumber().doubleValue(), 1e-5);


        // now converting to FlatBuffer
        val fb = tg.asFlatBuffers(true);
        assertNotNull(fb);

        val graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(5, graph.variablesLength());

        val nodeSlice = graph.nodes(0);

        assertEquals(14, nodeSlice.extraIntegerLength());

        val begin_mask = nodeSlice.extraInteger(0);
        val ellipsis_mask = nodeSlice.extraInteger(1);
        val end_mask = nodeSlice.extraInteger(2);
        val new_axis_mask = nodeSlice.extraInteger(3);
        val shrink_axis_mask = nodeSlice.extraInteger(4);

        assertEquals(0, begin_mask);
        assertEquals(0, ellipsis_mask);
        assertEquals(0, end_mask);
        assertEquals(0, new_axis_mask);
        assertEquals(0, shrink_axis_mask);

        val nodeSum = graph.nodes(1);

        /* assertEquals("StridedSlice", nodeSlice.name());
        assertEquals("Sum", nodeSum.name());
        */
        assertEquals(4, nodeSlice.inputPairedLength());
        assertEquals(2, nodeSum.inputPairedLength());

        // we expect these inputs to be 5:0 and 6:0 respectively
        // where 5 (or 6) is a graph node id
        // and :0 is the graph node output index, which is 0 because those are predefined variables
        // P.s. nodeSlice.id() should be equal to 5 :)
        val in0 = nodeSum.inputPaired(0);
        val in1 = nodeSum.inputPaired(1);
        /*
        assertEquals(5, nodeSlice.id());
        assertEquals(7, nodeSum.id());

        assertEquals(nodeSlice.id(), in0.first());
        assertEquals(5, in0.first());

        assertEquals(6, in1.first());
        assertEquals(0, in1.second());
        */

        // tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/tensor_slice.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build());
        /*
        val executioner = new NativeGraphExecutioner();

        val exp = Nd4j.create(3, 1).assign(3);

        val results = executioner.executeGraph(tg, configuration);

        assertNotNull(results);
        assertEquals(1, results.length);
        assertEquals(73.5f, results[0].getFloat(0), 1e-5f);*/
    }

    @Test
    @Ignore
    public void testIntermediateTensorArraySimple1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream());
        tg.setArrayForVariable("input_matrix", Nd4j.ones(3, 2));

        assertNotNull(tg);

        val firstSlice = tg.getVariable("strided_slice");


        val fb = tg.asFlatBuffers(true);
        assertNotNull(fb);

        val graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(36, graph.variablesLength());

        assertTrue(graph.nodesLength() > 1);
        /* assertEquals("strided_slice", graph.nodes(0).name());
        assertEquals("TensorArray", graph.nodes(1).name());
        */
        // assertEquals(4, graph.nodes(0).inputPairedLength());

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/tensor_array.fb"));
    }

    @Test
    @Ignore
    public void testIntermediateTensorArrayLoop1() throws Exception {
        val input = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream());
        tg.setArrayForVariable("input_matrix", input);
        assertNotNull(tg);

        val fb = tg.asFlatBuffers(true);
        assertNotNull(fb);

        val graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(12, graph.variablesLength());

        val strided_slice = graph.nodes(0);

        /* assertEquals("strided_slice", strided_slice.name());
        assertEquals("TensorArray", graph.nodes(1).name());
        */
        assertEquals(4, strided_slice.inputPairedLength());


        // we expect these inputs to be 1:0, 2:0, 3:0 and 4:0 respectively
        // where 1 (or 2/3/4) is a graph node id
        // and :0 is the graph node output index, which is 0 because those are predefined variables
        val in0 = strided_slice.inputPaired(0);
        val in1 = strided_slice.inputPaired(1);
        val in2 = strided_slice.inputPaired(2);
        val in3 = strided_slice.inputPaired(3);

        assertEquals(2, in0.first());
        assertEquals(0, in0.second());

        assertEquals(3, in1.first());
        assertEquals(0, in1.second());

        assertEquals(4, in2.first());
        assertEquals(0, in2.second());

        assertEquals(5, in3.first());
        assertEquals(0, in3.second());
    }

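    // Reminder on the inputPaired convention checked above and below: each pair is
    // (graph node id, output index), with output index 0 for predefined variables.
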
    @Test
    public void testIntermediateReduction() throws Exception {
        Nd4j.create(1);
        SameDiff tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream());
        SDVariable sumResultVar = tg.getVariable("Sum");

        /* val func = tg.getFunctionForVertexId(sumResultVar.getVertexId());
        assertEquals(0, func.getDimensions()[0]);
        assertEquals(3, tg.variables().size());
        assertNotNull(sumResultVar);
        assertNotNull(tg.getFunctionForVertexId(sumResultVar.getVertexId()));
        System.out.println(tg.variables());

        assertNotNull(func.getDimensions());
        assertEquals(0, func.getDimensions()[0]);*/

        ByteBuffer fb = tg.asFlatBuffers(true);
        assertNotNull(fb);

        FlatGraph graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(1, graph.nodesLength());
        assertEquals(2, graph.variablesLength());

        assertEquals("Sum", graph.nodes(0).name());

        FlatNode nodeSum = graph.nodes(0);
        assertEquals(2, nodeSum.inputPairedLength());


        // we expect these inputs to be 1:0 and 2:0 respectively
        // where 1 (or 2) is a graph node id
        // and :0 is the graph node output index, which is 0 because those are predefined variables
        val in0 = nodeSum.inputPaired(0);
        val in1 = nodeSum.inputPaired(1);

        assertEquals(1, in0.first());
        assertEquals(0, in0.second());

        assertEquals(2, in1.first());
        assertEquals(0, in1.second());

        System.out.println(tg.summary());

        int dimensionsLength = nodeSum.dimensionsLength();
        assertEquals(1, dimensionsLength);
        int d = nodeSum.dimensions(0);
        assertEquals(1, d);


        //log.info("nodeSum inputs length: {}; inputPaired length: {}", nodeSum.inputLength(), nodeSum.inputPairedLength());

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/reduce_dim.fb"));

        /*val executioner = new NativeGraphExecutioner();

        val exp = Nd4j.create(3, 1).assign(3);

        val results = executioner.executeGraph(tg, configuration);

        assertNotNull(results);
        assertEquals(1, results.length);
        assertEquals(exp, results[0]);
        */
    }

    @Test
    public void testDefaultArgs() {
        val op = new RectifiedLinear();

        val extras = op.extraArgs();
        assertTrue(extras.length == 1);
        val value = (Double) extras[0];

        assertEquals(0.0f, value.floatValue(), 1e-5f);
    }

    @Test
    public void testInferShape() throws IOException {
        /**
         node {
           name: "input"
           op: "Placeholder"
           attr {
             key: "dtype"
             value {
               type: DT_FLOAT
             }
           }
           attr {
             key: "shape"
             value {
               shape {
                 dim {
                   size: -1
                 }
                 dim {
                   size: 4
                 }
               }
             }
           }
         }
         node {
           name: "bias"
           op: "Const"
           attr {
             key: "dtype"
             value {
               type: DT_FLOAT
             }
           }
           attr {
             key: "value"
             value {
               tensor {
                 dtype: DT_FLOAT
                 tensor_shape {
                   dim {
                     size: 4
                   }
                 }
                 tensor_content: "\000\000\200?\000\000\000@\000\000@@\000\000\200@"
               }
             }
           }
         }
         node {
           name: "bias/read"
           op: "Identity"
           input: "bias"
           attr {
             key: "_class"
             value {
               list {
                 s: "loc:@bias"
               }
             }
           }
           attr {
             key: "T"
             value {
               type: DT_FLOAT
             }
           }
         }
         node {
           name: "output"
           op: "BiasAdd"
           input: "input"
           input: "bias/read"
           attr {
             key: "data_format"
             value {
               s: "NHWC"
             }
           }
           attr {
             key: "T"
             value {
               type: DT_FLOAT
             }
           }
         }
         library {
         }
         */
        SameDiff graph = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/bias_add/frozen_model.pb").getInputStream());
        assertNotNull(graph);

        INDArray input = Nd4j.linspace(1, 40, 40, DataType.FLOAT).reshape(10, 4);
        INDArray expectedOutput = Nd4j.linspace(1, 40, 40, DataType.FLOAT).reshape(10, 4).addRowVector(Nd4j.linspace(1, 4, 4, DataType.FLOAT));
        INDArray actual = graph.execSingle(Collections.singletonMap("input", input), graph.outputs().get(0));
        assertEquals(input, graph.getVariable("input").getArr());
        assertArrayEquals(input.shape(), graph.getShapeForVarName(graph.getVariable("input").getVarName()));
        assertEquals(expectedOutput, actual);
    }

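    // The mapping test below verifies that every imported DifferentialFunction's own name is present in
    // the variable map and matches the name of its first output variable.
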
    @Test
    public void testImportMapping1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream());

        val variables = new HashMap<String, SDVariable>();
        for (val var : tg.variables()) {
            variables.put(var.getVarName(), var);
        }

        val functions = new HashMap<String, DifferentialFunction>();
        for (val func : tg.functions()) {
            val ownName = func.getOwnName();
            val outName = func.outputVariables()[0].getVarName();

            assertTrue("Missing ownName: [" + ownName + "]", variables.containsKey(ownName));
            assertEquals(ownName, outName);
        }
    }

    @Test
    public void testCondMapping1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0_1.fb"));
        /*
        //log.info("{}", tg.asFlatPrint());
        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(-2);
        assertNotNull(array);
        assertEquals(exp, array);*/
    }

    @Test
    public void testCondMapping2() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        val input = Nd4j.create(2, 2).assign(-1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));

        //log.info("{}", tg.asFlatPrint());
        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }

    @Test
    public void testWhileMapping1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
        assertNotNull(tg);
        val input = Nd4j.create(2, 2).assign(1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_3.fb"));

        //log.info("{}", tg.asFlatPrint());


        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }

    @Test
    public void testWhileMapping2() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
        assertNotNull(tg);
        val input = Nd4j.trueScalar(4.0);
        tg.associateArrayWithVariable(input, tg.getVariable("input_1"));

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));

        //log.info("{}", tg.asFlatPrint());
        /*
        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(2);
        assertNotNull(array);
        assertEquals(exp, array);*/
    }

    @Test
    public void testWhileMapping3() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream());
        assertNotNull(tg);
        val input = Nd4j.trueScalar(9.0);
        tg.associateArrayWithVariable(input, tg.getVariable("input_1"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0.fb"));

        //log.info("{}", tg.asFlatPrint());

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(4);
        assertNotNull(array);
        assertEquals(exp, array);
    }

    @Test
    public void testWhileDualMapping1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-4.0);
        val input1 = Nd4j.trueScalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));

        //log.info("{}", tg.asFlatPrint());

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(-1);
        assertNotNull(array);
        assertEquals(exp, array);
    }

    @Test
    public void testWhileDualMapping2() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream());
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-9.0);
        val input1 = Nd4j.trueScalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));

        //log.info("{}", tg.asFlatPrint());

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(-3);
        assertNotNull(array);
        assertEquals(exp, array);
    }

    @Test
    public void testMixedWhileCond1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream());
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(1.0);
        val input1 = Nd4j.create(3, 3).assign(2.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_nested.fb"));

        //log.info("{}", tg.asFlatPrint());

        val array = tg.execAndEndResult();
        //val array = tg.getVariable("output").getArr();
        val exp = Nd4j.create(2, 2).assign(15.0);
        assertNotNull(array);
        assertEquals(exp, array);
    }

    @Test
    @Ignore
    public void testProfConv() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new File("/home/raver119/develop/workspace/models/profiling_conv.pb.txt"));
        assertNotNull(tg);

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/profiling_conv.fb"));
    }

    @Test
    @Ignore
    public void testCrash_119_matrix_diag() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        val input0 = Nd4j.create(2, 5, 4).assign(1.0);
        val input1 = Nd4j.create(2, 3, 5, 4).assign(2.0);
        val input2 = Nd4j.create(3, 1, 5, 4).assign(3.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        tg.associateArrayWithVariable(input2, tg.getVariable("input_2"));

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/partition_stitch_misc.fb"));
    }

    @Test
    @Ignore
    public void testCrash_119_tensor_dot_misc() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/tensor_dot_misc/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        val input0 = Nd4j.create(36, 3, 4, 5).assign(1.0);
        val input1 = Nd4j.create(5, 5, 3, 4).assign(2.0);

        tg.associateArrayWithVariable(input0, tg.getVariable("input_a"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_b"));

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/tensor_dot_misc.fb"));
    }

    @Test
    @Ignore
    public void testCrash_119_transpose() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/transpose/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        val input0 = Nd4j.create(new double[]{0.98114507, 0.96400015, 0.58669623, 0.60073098, 0.75425418, 0.44258752, 0.76373084, 0.96593234, 0.34067846}, new int[] {3, 3});
        val input1 = Nd4j.create(new double[]{0.98114507, 0.60073098, 0.76373084, 0.96400015, 0.75425418, 0.96593234, 0.58669623, 0.44258752, 0.34067846}, new int[] {3, 3});

        tg.associateArrayWithVariable(input0, tg.getVariable("input"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/transpose.fb"));
    }

    @Test
    //@Ignore
    public void testCrash_119_simpleif_0() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        val input0 = Nd4j.create(new float[] {1, 2, 3, 4}, new int[] {2, 2});
        val input1 = Nd4j.trueScalar(11f);

        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));
    }

    @Test
    @Ignore
    public void testCrash_119_ae_00() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/ae_00/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        val input0 = Nd4j.create(new double[] {0.98174960, 0.44406342, 0.50100771, 1.00000000, -0.94038386, 0.46501783, -0.49040590, 0.98153842, -0.00198260, 0.49108310, -0.06085236, 0.93523693, -0.05857396, -0.46633510, -0.02806635, -0.96879626, -0.03938015, -0.51578135, -0.06333921, -1.00000000}, new int[] {5, 4});

        tg.associateArrayWithVariable(input0, tg.getVariable("input"));

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/ae_00.fb"));
    }

    @Test
    @Ignore
    public void testCrash_119_expand_dim() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        val input0 = Nd4j.create(new double[] {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743}, new int[] {3, 4});

        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/expand_dim.fb"));
    }

    @Test
    //@Ignore
    public void testCrash_119_reduce_dim_false() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream());
        assertNotNull(tg);


        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/reduce_dim_false.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build(), true);
    }

    @Test
    //@Ignore
    public void testCrash_119_reduce_dim_true() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/reduce_dim_true.pb.txt").getInputStream());
        assertNotNull(tg);

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/reduce_dim_true.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build(), true);
    }

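    // The TensorArray tests below run small write/read/unstack/loop graphs on fixed inputs and compare
    // the graph's first output against hand-computed expected arrays.
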
    @Test
    public void testTensorArray_119_1() throws Exception {
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream());
        assertNotNull(tg);

        val input_matrix = Nd4j.ones(3, 2);
        val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));

        val exp = Nd4j.create(new float[] {1, 1, 2, 2, 3, 3}, new int[]{3, 2});

        assertEquals(exp, array);
    }

    @Test
    public void testTensorArray_119_2() throws Exception {
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_read.pb.txt").getInputStream());
        assertNotNull(tg);

        val input_matrix = Nd4j.ones(3, 2);

        val array = tg.exec(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));

        val exp = Nd4j.create(new float[] {2, 2}, new int[]{2});

        assertEquals(exp, array);
    }

    @Test
    public void testTensorArray_119_3() throws Exception {
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_unstack.pb.txt").getInputStream());
        assertNotNull(tg);

        val array = tg.execSingle(Collections.emptyMap(), tg.outputs().get(0));

        val exp = Nd4j.create(new float[] {5, 6, 7, 8}, new int[]{4});

        assertEquals(exp, array);
    }

    @Test
    public void testTensorArray_119_4() throws Exception {
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/tensor_array_loop.pb.txt").getInputStream());
        assertNotNull(tg);

        val input_matrix = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2);
        log.info("Graph: {}", tg.asFlatPrint());
        val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));

        val exp = Nd4j.create(new float[] {3, 6, 9, 12, 15, 18, 21, 24, 27, 30}, new int[]{5, 2});

        assertEquals(exp, array);
    }

    @Test
    public void testLossImport_1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/losses/log_loss_rank2_axis1_SUM_OVER_BATCH_SIZE/frozen_model.pb").getInputStream());

        tg.execAndEndResult();
    }

    @Test
    public void testG_1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/g_08/frozen_model.pb").getInputStream());

        val g = tg.asFlatBuffers(true);
    }

    @Test
    public void testBoolImport_1() throws Exception {
        Nd4j.create(1);
        for (int e = 0; e < 1000; e++) {
            val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/reduce_any/rank0/frozen_model.pb").getInputStream());

            Map<String, INDArray> result = tg.exec(Collections.emptyMap(), tg.outputs());

            assertNotNull(result);
            assertTrue(result.size() > 0);
        }
    }

    @Test
    public void testLogical_1() throws Exception {
        Nd4j.create(1);
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/transforms/logicalxor_3,4_3,4/frozen_model.pb").getInputStream());

        tg.execAndEndResult();
    }

    @Test
    public void testSSD_1() throws Exception {
        // tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb
        Nd4j.create(1);

        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/ssd_inception_v2_coco_2018_01_28/frozen_inference_graph.pb").getInputStream());
        assertNotNull(tg);
    }

    @Test(expected = ND4JIllegalStateException.class)
    public void testNonFrozenGraph1() throws Exception {
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/unfrozen_simple_ae.pb").getInputStream());
    }

    @Test
    public void testRandomGraph() throws Exception {
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/scalar_float32/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/scalar_float32.fb"));
    }

    @Test
    public void testRandomGraph2() throws Exception {
        val tg = TFGraphMapper.getInstance().importGraph(new File("c:\\develop\\mobilenet_v2_1.0_224_frozen.pb"));
        assertNotNull(tg);

        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/mobilenet_v2.fb"));
    }

    @Test
    @Ignore
    public void testRandomGraph3() throws Exception {
        val tg = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/assert_equal/3,4_3,4_float32/frozen_model.pb").getInputStream());
        assertNotNull(tg);

        log.info("{}", tg.asFlatPrint());
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/assertsomething.fb"));
    }

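    // The final test checks that TensorFlow control-dependency edges (on the cond switch outputs)
    // survive import and are exposed through Variable.getControlDeps().
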
    @Test
    public void testControlDependencies1() throws Exception {
        SameDiff sd = TFGraphMapper.getInstance().importGraph(new ClassPathResource("tf_graphs/examples/cond/cond_true/frozen_model.pb").getInputStream());


        /*
        Control dependencies:
        variables:
            - cond/LinSpace/start - depends on cond/switch_t
            - cond/LinSpace/stop - depends on cond/switch_t
            - cond/LinSpace/num - depends on cond/switch_t
            - cond/ones - depends on cond/switch_f
         */

        Map<String, Variable> variables = sd.getVariables();

        assertEquals(variables.get("cond/LinSpace/start").getControlDeps(), Collections.singletonList("cond/switch_t"));
        assertEquals(variables.get("cond/LinSpace/stop").getControlDeps(), Collections.singletonList("cond/switch_t"));
        assertEquals(variables.get("cond/LinSpace/num").getControlDeps(), Collections.singletonList("cond/switch_t"));
        assertEquals(variables.get("cond/ones").getControlDeps(), Collections.singletonList("cond/switch_f"));
    }
}