tf.keras model import (#258)

* tf op initial

* ..

* protobuf parsing working

* model build working

* test passing

* headers

* conffix

* service loader + tests

* revert cuda version

* msg

* override

* refacc

* pom

* rem bad import

* dtype fix + const cast caching

* rem unnecessary fields

* rem println

* rem dep

* refacc

* rem redundant arg

* Ignore TFOpLayer in DTypeTests

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
master
Fariz Rahman 2020-03-24 13:37:27 +04:00 committed by GitHub
parent ec6abacdb8
commit b1bc7df160
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 550 additions and 6 deletions

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.dtypes;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
@ -128,7 +129,7 @@ public class DTypeTests extends BaseDL4JTest {
throw new RuntimeException(e);
}
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface()) {
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) { //Skip TFOpLayer here - dtype depends on imported model dtype
continue;
}

View File

@ -105,6 +105,14 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-tensorflow</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -103,4 +103,6 @@ public class Keras2LayerConfiguration extends KerasLayerConfiguration {
/* Keras weight initializers. */
private final String LAYER_FIELD_INIT = "kernel_initializer";
private final String TENSORFLOW_OP_LAYER = "TensorFlowOpLayer";
}

View File

@ -0,0 +1,74 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras.layers;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import java.util.Map;
public class KerasTFOpLayer extends KerasLayer {

    /**
     * Constructor for a given Keras major version.
     *
     * @param kerasVersion Keras major version; only version 2 is supported
     * @throws UnsupportedKerasConfigurationException if the version is not 2
     */
    public KerasTFOpLayer(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
        super(kerasVersion);
        if (kerasVersion != 2) {
            throw new UnsupportedKerasConfigurationException("KerasTFOpLayer expects Keras version 2");
        }
    }

    /**
     * Constructor from parsed Keras layer configuration dictionary.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @throws InvalidKerasConfigurationException Invalid Keras config
     * @throws UnsupportedKerasConfigurationException Unsupported Keras config
     */
    public KerasTFOpLayer(Map<String, Object> layerConfig)
            throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    /**
     * Constructor from parsed Keras layer configuration dictionary.
     *
     * @param layerConfig dictionary containing Keras layer configuration
     * @param enforceTrainingConfig whether to enforce training-related configuration options
     * @throws InvalidKerasConfigurationException Invalid Keras config
     * @throws UnsupportedKerasConfigurationException Unsupported Keras config
     */
    public KerasTFOpLayer(Map<String, Object> layerConfig, boolean enforceTrainingConfig)
            throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
        // The tf.keras TensorFlowOpLayer config carries the TF NodeDef and any constant
        // inputs under its "config" sub-map; hand both straight to the DL4J wrapper layer.
        Map config = (Map) layerConfig.get("config");
        this.layer = new TFOpLayer((Map) config.get("node_def"), (Map) config.get("constants"));
    }

    /**
     * Get layer output type.
     *
     * @param inputType Array of InputTypes; only the first element is used
     * @return output type as InputType
     */
    public InputType getOutputType(InputType... inputType) {
        return layer.getOutputType(0, inputType[0]);
    }
}

View File

@ -0,0 +1,106 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras.layers;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayerImpl;
import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.Regularization;
import java.util.Collection;
import java.util.List;
import java.util.Map;
public class TFOpLayer extends Layer {

    private Map nodeDef;
    private Map constants;

    /**
     * Configuration-side layer wrapping a single TensorFlow op imported from a tf.keras model.
     *
     * @param nodeDef   map form of the TensorFlow NodeDef protobuf describing the op
     * @param constants constant input arrays, keyed by input index (as a string)
     */
    public TFOpLayer(Map nodeDef, Map constants) {
        super();
        this.nodeDef = nodeDef;
        this.constants = constants;
    }

    /** This layer has no trainable parameters, so the shared empty initializer is used. */
    @Override
    public ParamInitializer initializer() {
        return EmptyParamInitializer.getInstance();
    }

    /** No input pre-processing is needed for any input type. */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return null;
    }

    @Override
    public boolean isPretrainParam(String param) {
        return false;
    }

    /**
     * Infers the output type by building a throwaway runtime layer and running the op
     * on a zero-filled array of the given input shape.
     */
    @Override
    public InputType getOutputType(int idx, InputType inputType) {
        long[] inShape = inputType.getShape(true);
        TFOpLayerImpl probe = new TFOpLayerImpl(nodeDef, constants, null, null);
        long[] outShape = probe.getOutputShape(inShape);
        return InputType.inferInputType(Nd4j.create(outShape));
    }

    @Override
    public void setNIn(InputType inputType, boolean override) {
        // no-op: there is no nIn to configure for a TF op layer
    }

    @Override
    public GradientNormalization getGradientNormalization() {
        return null;
    }

    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
            Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
            boolean initializeParams, DataType networkDataType) {
        TFOpLayerImpl impl = new TFOpLayerImpl(nodeDef, constants, conf, networkDataType);
        impl.setListeners(trainingListeners);
        impl.setIndex(layerIndex);
        return impl;
    }

    @Override
    public double getGradientNormalizationThreshold() {
        return 0.;
    }

    @Override
    public List<Regularization> getRegularizationByParam(String paramName) {
        return null;
    }

    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        return new LayerMemoryReport(); //TODO
    }
}

View File

@ -0,0 +1,169 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.TFGraphRunnerService;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import com.google.gson.Gson;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat;
import java.util.*;
import java.util.List;
@Slf4j
@Data
/**
 * Runtime layer that executes a single TensorFlow op (imported from a tf.keras model)
 * by delegating to an {@link TFGraphRunnerService} implementation discovered via
 * {@link ServiceLoader}. Inference-only: backprop is not supported.
 */
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {

    // Map form of the TF NodeDef protobuf for the op (as parsed from the Keras HDF5 config)
    private Map nodeDef;
    // Constant input arrays, keyed by input index as a string (e.g. "1" -> list of numbers)
    private Map constants;
    // Names of the non-constant (placeholder) inputs, populated by setGraphRunner()
    private List<String> inputNames;
    TFGraphRunnerService graphRunnerService;

    public TFOpLayerImpl(Map nodeDef, Map constants, NeuralNetConfiguration conf, DataType dtype){
        super(conf, dtype);
        this.nodeDef = nodeDef;
        this.constants = constants;
        // Eagerly builds the TF graph and runner; throws if nd4j-tensorflow is not on the classpath
        setGraphRunner();
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr){
        throw new RuntimeException("Backprop through TFOpLayerImpl is not supported yet." +
                " TFOpLayerImpl is created when importing TensorFlow 2.0 Keras models " +
                "(tf.keras) into DL4J, that contains TensorFlow operations not just Keras layers.");
    }

    /**
     * Converts a Map representation of Nodedef to a singleton TF Graph and instantiates a GraphRunner.
     *
     * Steps: (1) round-trip the node-def map through JSON into a protobuf NodeDef;
     * (2) classify each input as constant vs. placeholder and record its dtype;
     * (3) assemble a single-node GraphDef in protobuf text format, prepending a
     * Placeholder node per input; (4) hand the serialized graph to a
     * TFGraphRunnerService found via ServiceLoader.
     */
    private void setGraphRunner() {
        try{
            // Map -> JSON -> protobuf NodeDef (JsonFormat handles the field mapping)
            String json = new Gson().toJson(nodeDef);
            NodeDef.Builder builder = NodeDef.newBuilder();
            org.nd4j.shade.protobuf.util.JsonFormat.parser().merge(json, builder);
            NodeDef nodeDef = builder.build();  // shadows the field intentionally
            List<String> allInputNames = new ArrayList<>(); // including constants
            Map<String, String> inputDataTypes = new HashMap<>();
            Map<String, INDArray> constArrays = new HashMap();
            this.inputNames = new ArrayList<>();
            // Single-op graph: the op's own name is the only output
            List<String> outputNames = Arrays.asList(nodeDef.getName());
            Map<String, AttrValue> attrMap = nodeDef.getAttrMap();
            for (int i = 0; i < nodeDef.getInputCount(); i++){
                String inputName = nodeDef.getInput(i);
                String[] split = inputName.split("/");
                // NOTE(review): dtype lookup assumes tf.keras records per-input dtypes under
                // attr keys "T", "T<suffix>" where <suffix> is the last path segment of the
                // input name — TODO confirm this convention holds for all imported ops
                String attrKey;
                if (split.length == 1){
                    attrKey = "T";
                }
                else{
                    attrKey = "T" + split[split.length - 1];
                }
                allInputNames.add(nodeDef.getInput(i));
                inputDataTypes.put(nodeDef.getInput(i), attrMap.get(attrKey).getType().toString());
                // Constants are keyed by input position ("0", "1", ...) in the Keras config
                if (constants.containsKey(String.valueOf(i))){
                    constArrays.put(nodeDef.getInput(i), Nd4j.create((List<Number>)constants.get(String.valueOf(i))));
                }
                else{
                    this.inputNames.add(nodeDef.getInput(i));
                }
            }
            // Build the GraphDef in protobuf *text* format: the op node plus a fixed
            // "versions" block (producer: 22)
            String graph = "node{\n" + nodeDef.toString() + "\n}\nversions {\n producer: 22\n}";
            // Prepend one Placeholder node per input (constants included; the runner
            // service receives the constant arrays separately via constArrays)
            for (int i = 0; i < allInputNames.size(); i++){
                String inpName = allInputNames.get(i);
                String dtype = inputDataTypes.get(inpName);
                graph = "node{\nname: \"" + inpName + "\"\nop: \"Placeholder\"\nattr{\nkey: \"dtype\"\n value {\n type: " + dtype + "}\n}\n}\n" + graph;
            }
            log.info(graph);
            // Parse the text-format graph and serialize it to bytes for the runner service
            GraphDef.Builder graphDefBuilder = GraphDef.newBuilder();
            TextFormat.getParser().merge(graph, graphDefBuilder);
            GraphDef graphDef = graphDefBuilder.build();
            org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString();
            byte[] graphBytes = serialized.toByteArray();
            // Locate a TFGraphRunnerService implementation (provided by nd4j-tensorflow)
            ServiceLoader<TFGraphRunnerService> sl = ServiceLoader.load(TFGraphRunnerService.class);
            Iterator<TFGraphRunnerService> iter = sl.iterator();
            if (!iter.hasNext()){
                throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
            }
            this.graphRunnerService = iter.next().init(allInputNames, outputNames, graphBytes, constArrays, inputDataTypes);
        }
        catch (Exception e){
            // Wraps any failure (JSON, protobuf text parse, missing attrs, service init)
            throw new RuntimeException("Error parsing protobuf", e);
        }
    }

    /**
     * Runs the TF graph on a single input array, permuting rank-3 input from
     * NCW to NWC before execution and back afterwards.
     */
    private INDArray runGraph(INDArray input){
        if (input.rank() == 3){
            // TODO make this a preprocessor
            input = input.permute(0, 2, 1);
        }
        Map<String, INDArray> inputMap = new HashMap<>();
        // Single non-constant input assumed — only inputNames.get(0) is fed
        inputMap.put(inputNames.get(0), input);
        INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
        if (out.rank() == 3){
            out = out.permute(0, 2, 1); // TODO post-processing?
        }
        return out;
    }

    /**
     * Computes the op's output shape by executing it on a zero array of the given
     * shape, with any negative (unknown) dimensions replaced by 1.
     */
    public long[] getOutputShape(long[] inputShape){
        long[] shape = ArrayUtils.clone(inputShape);
        for(int i = 0; i < shape.length; i++){
            if (shape[i] < 0){
                shape[i] = 1;
            }
        }
        INDArray dummyArr = Nd4j.zeros(shape);
        return runGraph(dummyArr).shape();
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
        return runGraph(input);
    }

    @Override
    public boolean isPretrainLayer(){
        return false;
    }

    @Override
    public void clearNoiseWeightParams(){
    }
}

View File

@ -21,10 +21,12 @@ import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
import org.deeplearning4j.nn.modelimport.keras.layers.KerasTFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
@ -317,6 +319,11 @@ public class KerasLayerUtils {
layer = new KerasELU(layerConfig, enforceTrainingConfig);
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
} else if (conf instanceof Keras2LayerConfiguration){
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){
layer = new KerasTFOpLayer(layerConfig, enforceTrainingConfig);
}
}
if (layer == null){
Class<? extends KerasLayer> customConfig = customLayers.get(layerClassName);
@ -402,6 +409,16 @@ public class KerasLayerUtils {
public static String getLayerNameFromConfig(Map<String, Object> layerConfig,
KerasLayerConfiguration conf)
throws InvalidKerasConfigurationException {
if(conf instanceof Keras2LayerConfiguration){
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
if (getClassNameFromConfig(layerConfig, conf).equals(((Keras2LayerConfiguration) conf).getTENSORFLOW_OP_LAYER())){
if (!layerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
+ " missing from layer config");
return (String) layerConfig.get(conf.getLAYER_FIELD_NAME());
}
}
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
if (!innerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()

View File

@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.Arrays;
public class TFKerasTests extends BaseDL4JTest {

    @Test
    public void testModelWithTFOp1() throws Exception {
        // tf.keras model containing a TF Reshape op
        File modelFile = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
        ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelFile.getAbsolutePath());
        INDArray output = model.outputSingle(Nd4j.zeros(12, 2, 3));
        Assert.assertArrayEquals(new long[]{12, 3}, output.shape());
    }

    @Test
    public void testModelWithTFOp2() throws Exception {
        // tf.keras model containing a TF Permute/Transpose op
        File modelFile = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
        ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelFile.getAbsolutePath());
        INDArray output = model.outputSingle(Nd4j.zeros(12, 2, 3));
        // DL4J feedforward output can't be 3D, so the batch and time axes are flattened together
        long[] expectedShape = new long[]{12 * 2, 5};
        Assert.assertArrayEquals(expectedShape, output.shape());
    }
}

View File

@ -77,7 +77,11 @@
<artifactId>nd4j-common</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>${gson.version}</version>
</dependency>
<!-- ND4J Shaded Jackson Dependency -->
<dependency>
<groupId>org.nd4j</groupId>

View File

@ -62,6 +62,7 @@ public abstract class AbstractLayer<LayerConfT extends org.deeplearning4j.nn.con
/**
 * Constructor allowing the network data type to be set alongside the configuration.
 *
 * @param conf     layer/network configuration; may be null (cacheMode is then left unset)
 * @param dataType network data type, always assigned regardless of conf
 */
public AbstractLayer(NeuralNetConfiguration conf, DataType dataType) {
    this.conf = conf;
    // NOTE: unbraced if — only the cacheMode assignment is guarded by the null check;
    // the dataType assignment below always executes
    if (conf != null)
        cacheMode = conf.getCacheMode();
    this.dataType = dataType;
}

View File

@ -0,0 +1,37 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.List;
import java.util.Map;
/**
 * Service-provider interface for executing a serialized TensorFlow GraphDef.
 * Implementations are discovered via {@link java.util.ServiceLoader}; the concrete
 * implementation is supplied by the nd4j-tensorflow module so that the Keras model
 * import code has no compile-time TensorFlow dependency.
 */
public interface TFGraphRunnerService{
    /**
     * Initializes the runner with a serialized graph and its I/O metadata.
     *
     * @param inputNames     names of all graph inputs (constants included)
     * @param outputNames    names of the graph outputs to fetch
     * @param graphBytes     serialized GraphDef protobuf
     * @param constants      constant input arrays, keyed by input name
     * @param inputDataTypes TF data type name per input name
     * @return this service, initialized and ready to run
     */
    TFGraphRunnerService init(
            List<String> inputNames,
            List<String> outputNames,
            byte[] graphBytes,
            Map<String, INDArray> constants,
            Map<String, String> inputDataTypes
    );

    /**
     * Executes the graph.
     *
     * @param inputs input arrays keyed by input name
     * @return output arrays keyed by output name
     */
    Map<String,INDArray> run(Map<String,INDArray> inputs);
}

View File

@ -16,18 +16,16 @@
package org.nd4j.tensorflow.conversion.graphrunner;
import lombok.Builder;
import lombok.Singular;
import lombok.*;
import org.apache.commons.io.FileUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.protobuf.ByteString;
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
import org.nd4j.shade.protobuf.util.JsonFormat;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.tensorflow.conversion.TensorDataType;
import org.apache.commons.io.IOUtils;
@ -56,6 +54,7 @@ import static org.bytedeco.tensorflow.global.tensorflow.*;
* @author Adam Gibson
*/
@Slf4j
@NoArgsConstructor
public class GraphRunner implements Closeable {
private static boolean isTfWarmedUp = false;
@ -103,6 +102,9 @@ public class GraphRunner implements Closeable {
* @param inputDataTypes the expected input data types
* @param outputDataTypes the expected output data types
*/
@Builder
public GraphRunner(List<String> inputNames,
List<String> outputNames,
@ -440,6 +442,7 @@ public class GraphRunner implements Closeable {
* @return a map of the output names to the
* ndarrays matching each output specified in the graph
*/
public Map<String,INDArray> run(Map<String,INDArray> inputs) {
if (!isTfWarmedUp && !isTfWarmingUp){
isTfWarmingUp = true;
@ -683,4 +686,7 @@ public class GraphRunner implements Closeable {
return builder1.build();
}
}

View File

@ -0,0 +1,52 @@
package org.nd4j.tensorflow.conversion.graphrunner;
import org.nd4j.TFGraphRunnerService;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.tensorflow.conversion.TensorDataType;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class GraphRunnerServiceProvider implements TFGraphRunnerService {

    private GraphRunner graphRunner;
    // Holds the (dtype-cast) constants after init(); run() merges caller inputs into it
    Map<String, INDArray> inputs;

    /**
     * Initializes a {@link GraphRunner} for the given serialized graph, converting the
     * string dtype names to {@link TensorDataType} and casting constants to match.
     */
    @Override
    public TFGraphRunnerService init(
            List<String> inputNames,
            List<String> outputNames,
            byte[] graphBytes,
            Map<String, INDArray> constants,
            Map<String, String> inputDataTypes) {
        if (inputNames.size() != inputDataTypes.size()) {
            throw new IllegalArgumentException("inputNames.size() != inputDataTypes.size()");
        }
        // Translate the TF proto dtype names into TensorDataType per input
        Map<String, TensorDataType> convertedDataTypes = new HashMap<>();
        for (String name : inputNames) {
            convertedDataTypes.put(name, TensorDataType.fromProtoValue(inputDataTypes.get(name)));
        }
        // Cast each constant to the ND4J dtype the graph declares for that input
        Map<String, INDArray> castConstants = new HashMap<>();
        for (Map.Entry<String, INDArray> entry : constants.entrySet()) {
            DataType requiredDtype =
                    TensorDataType.toNd4jType(TensorDataType.fromProtoValue(inputDataTypes.get(entry.getKey())));
            castConstants.put(entry.getKey(), entry.getValue().castTo(requiredDtype));
        }
        this.inputs = castConstants;
        graphRunner = GraphRunner.builder()
                .inputNames(inputNames)
                .outputNames(outputNames)
                .graphBytes(graphBytes)
                .inputDataTypes(convertedDataTypes)
                .build();
        return this;
    }

    /**
     * Runs the graph. Caller-supplied inputs are merged over the stored constants
     * (and any inputs from previous calls with the same names).
     */
    @Override
    public Map<String, INDArray> run(Map<String, INDArray> inputs) {
        if (graphRunner == null) {
            throw new RuntimeException("GraphRunner not initialized.");
        }
        this.inputs.putAll(inputs);
        return graphRunner.run(this.inputs);
    }
}

View File

@ -0,0 +1,17 @@
################################################################################
# Copyright (c) 2020 Konduit K.K..
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
org.nd4j.tensorflow.conversion.graphrunner.GraphRunnerServiceProvider