tf.keras model import (#258)
* tf op initial * .. * protobuf parsing working * model build working * test passing * headers * conffix * service loader + tests * revert cuda version * msg * override * refacc * pom * rem bad import * dtype fix + const cast caaching * rem unnecessary fields * rem println * rem dep * refacc * rem redundant arg * Ignore TFOpLayer in DTypeTests Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
ec6abacdb8
commit
b1bc7df160
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.nn.dtypes;
|
||||
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
|
||||
import org.nd4j.shade.guava.collect.ImmutableSet;
|
||||
import org.nd4j.shade.guava.reflect.ClassPath;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -128,7 +129,7 @@ public class DTypeTests extends BaseDL4JTest {
|
|||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface()) {
|
||||
if (Modifier.isAbstract(clazz.getModifiers()) || clazz.isInterface() || TFOpLayer.class == clazz) { //Skip TFOpLayer here - dtype depends on imported model dtype
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -105,6 +105,14 @@
|
|||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-tensorflow</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
|
|
@ -103,4 +103,6 @@ public class Keras2LayerConfiguration extends KerasLayerConfiguration {
|
|||
|
||||
/* Keras weight initializers. */
|
||||
private final String LAYER_FIELD_INIT = "kernel_initializer";
|
||||
|
||||
private final String TENSORFLOW_OP_LAYER = "TensorFlowOpLayer";
|
||||
}
|
|
@ -0,0 +1,74 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
public class KerasTFOpLayer extends KerasLayer {
|
||||
|
||||
public KerasTFOpLayer(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
|
||||
super(kerasVersion);
|
||||
if (kerasVersion != 2){
|
||||
throw new UnsupportedKerasConfigurationException("KerasTFOpLayer expects Keras version 2");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor from parsed Keras layer configuration dictionary.
|
||||
*
|
||||
* @param layerConfig dictionary containing Keras layer configuration
|
||||
* @throws InvalidKerasConfigurationException Invalid Keras config
|
||||
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
|
||||
*/
|
||||
public KerasTFOpLayer(Map<String, Object> layerConfig)
|
||||
throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
|
||||
this(layerConfig, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor from parsed Keras layer configuration dictionary.
|
||||
*
|
||||
* @param layerConfig dictionary containing Keras layer configuration
|
||||
* @param enforceTrainingConfig whether to enforce training-related configuration options
|
||||
* @throws InvalidKerasConfigurationException Invalid Keras config
|
||||
* @throws UnsupportedKerasConfigurationException Unsupported Keras config
|
||||
*/
|
||||
public KerasTFOpLayer(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException{
|
||||
super(layerConfig, enforceTrainingConfig);
|
||||
this.layer = new TFOpLayer((Map)((Map)layerConfig.get("config")).get("node_def"), (Map)((Map)layerConfig.get("config")).get("constants"));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get layer output type.
|
||||
*
|
||||
* @param inputType Array of InputTypes
|
||||
* @return output type as InputType
|
||||
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
||||
*/
|
||||
public InputType getOutputType(InputType... inputType){
|
||||
return this.layer.getOutputType(0, inputType[0]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||
|
||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||
import org.deeplearning4j.nn.conf.GradientNormalization;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayerImpl;
|
||||
import org.deeplearning4j.nn.params.EmptyParamInitializer;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
public class TFOpLayer extends Layer {
|
||||
|
||||
private Map nodeDef;
|
||||
private Map constants;
|
||||
public TFOpLayer(Map nodeDef, Map constants){
|
||||
super();
|
||||
this.nodeDef = nodeDef;
|
||||
this.constants = constants;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParamInitializer initializer() {
|
||||
return EmptyParamInitializer.getInstance();
|
||||
}
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPretrainParam(String param){
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int idx, InputType inputType){
|
||||
long[] shape = inputType.getShape(true);
|
||||
TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
|
||||
long[] outputShape = tempLayer.getOutputShape(shape);
|
||||
return InputType.inferInputType(Nd4j.create(outputShape));
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override){}
|
||||
|
||||
|
||||
@Override
|
||||
public GradientNormalization getGradientNormalization(){return null;}
|
||||
|
||||
|
||||
@Override
|
||||
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
|
||||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
||||
boolean initializeParams, DataType networkDataType) {
|
||||
|
||||
TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, conf, networkDataType);
|
||||
tfOpLayerImpl.setListeners(trainingListeners);
|
||||
tfOpLayerImpl.setIndex(layerIndex);
|
||||
return tfOpLayerImpl;
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getGradientNormalizationThreshold(){return 0.;}
|
||||
|
||||
@Override
|
||||
public List<Regularization> getRegularizationByParam(String paramName){return null;}
|
||||
|
||||
@Override
|
||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||
return new LayerMemoryReport(); //TODO
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.gradient.Gradient;
|
||||
import org.deeplearning4j.nn.layers.AbstractLayer;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.nd4j.TFGraphRunnerService;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
import com.google.gson.Gson;
|
||||
import org.nd4j.shade.protobuf.Message;
|
||||
import org.nd4j.shade.protobuf.TextFormat;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@Data
|
||||
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
||||
|
||||
|
||||
private Map nodeDef;
|
||||
private Map constants;
|
||||
private List<String> inputNames;
|
||||
TFGraphRunnerService graphRunnerService;
|
||||
|
||||
public TFOpLayerImpl(Map nodeDef, Map constants, NeuralNetConfiguration conf, DataType dtype){
|
||||
super(conf, dtype);
|
||||
this.nodeDef = nodeDef;
|
||||
this.constants = constants;
|
||||
setGraphRunner();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr){
|
||||
throw new RuntimeException("Backprop through TFOpLayerImpl is not supported yet." +
|
||||
" TFOpLayerImpl is created when importing TensorFlow 2.0 Keras models " +
|
||||
"(tf.keras) into DL4J, that contains TensorFlow operations not just Keras layers.");
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a Map representation of Nodedef to a singleton TF Graph and instantiates a GraphRunner.
|
||||
*/
|
||||
private void setGraphRunner() {
|
||||
try{
|
||||
String json = new Gson().toJson(nodeDef);
|
||||
NodeDef.Builder builder = NodeDef.newBuilder();
|
||||
org.nd4j.shade.protobuf.util.JsonFormat.parser().merge(json, builder);
|
||||
NodeDef nodeDef = builder.build();
|
||||
List<String> allInputNames = new ArrayList<>(); // including constants
|
||||
Map<String, String> inputDataTypes = new HashMap<>();
|
||||
Map<String, INDArray> constArrays = new HashMap();
|
||||
this.inputNames = new ArrayList<>();
|
||||
List<String> outputNames = Arrays.asList(nodeDef.getName());
|
||||
Map<String, AttrValue> attrMap = nodeDef.getAttrMap();
|
||||
for (int i = 0; i < nodeDef.getInputCount(); i++){
|
||||
String inputName = nodeDef.getInput(i);
|
||||
String[] split = inputName.split("/");
|
||||
String attrKey;
|
||||
if (split.length == 1){
|
||||
attrKey = "T";
|
||||
}
|
||||
else{
|
||||
attrKey = "T" + split[split.length - 1];
|
||||
}
|
||||
allInputNames.add(nodeDef.getInput(i));
|
||||
inputDataTypes.put(nodeDef.getInput(i), attrMap.get(attrKey).getType().toString());
|
||||
if (constants.containsKey(String.valueOf(i))){
|
||||
constArrays.put(nodeDef.getInput(i), Nd4j.create((List<Number>)constants.get(String.valueOf(i))));
|
||||
}
|
||||
else{
|
||||
this.inputNames.add(nodeDef.getInput(i));
|
||||
}
|
||||
}
|
||||
String graph = "node{\n" + nodeDef.toString() + "\n}\nversions {\n producer: 22\n}";
|
||||
for (int i = 0; i < allInputNames.size(); i++){
|
||||
String inpName = allInputNames.get(i);
|
||||
String dtype = inputDataTypes.get(inpName);
|
||||
graph = "node{\nname: \"" + inpName + "\"\nop: \"Placeholder\"\nattr{\nkey: \"dtype\"\n value {\n type: " + dtype + "}\n}\n}\n" + graph;
|
||||
}
|
||||
log.info(graph);
|
||||
GraphDef.Builder graphDefBuilder = GraphDef.newBuilder();
|
||||
TextFormat.getParser().merge(graph, graphDefBuilder);
|
||||
GraphDef graphDef = graphDefBuilder.build();
|
||||
org.nd4j.shade.protobuf.ByteString serialized = graphDef.toByteString();
|
||||
byte[] graphBytes = serialized.toByteArray();
|
||||
|
||||
ServiceLoader<TFGraphRunnerService> sl = ServiceLoader.load(TFGraphRunnerService.class);
|
||||
Iterator<TFGraphRunnerService> iter = sl.iterator();
|
||||
if (!iter.hasNext()){
|
||||
throw new RuntimeException("The model contains a Tensorflow Op, which requires the nd4j-tensorflow dependency to execute.");
|
||||
}
|
||||
|
||||
this.graphRunnerService = iter.next().init(allInputNames, outputNames, graphBytes, constArrays, inputDataTypes);
|
||||
}
|
||||
catch (Exception e){
|
||||
throw new RuntimeException("Error parsing protobuf", e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private INDArray runGraph(INDArray input){
|
||||
if (input.rank() == 3){
|
||||
// TODO make this a preprocessor
|
||||
input = input.permute(0, 2, 1);
|
||||
}
|
||||
Map<String, INDArray> inputMap = new HashMap<>();
|
||||
inputMap.put(inputNames.get(0), input);
|
||||
INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
|
||||
if (out.rank() == 3){
|
||||
out = out.permute(0, 2, 1); // TODO post-processing?
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
public long[] getOutputShape(long[] inputShape){
|
||||
long[] shape = ArrayUtils.clone(inputShape);
|
||||
for(int i = 0; i < shape.length; i++){
|
||||
if (shape[i] < 0){
|
||||
shape[i] = 1;
|
||||
}
|
||||
}
|
||||
INDArray dummyArr = Nd4j.zeros(shape);
|
||||
return runGraph(dummyArr).shape();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
|
||||
return runGraph(input);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean isPretrainLayer(){
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clearNoiseWeightParams(){
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -21,10 +21,12 @@ import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
|
|||
import org.deeplearning4j.nn.conf.layers.Layer;
|
||||
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.KerasInput;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.KerasTFOpLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
|
||||
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
|
||||
|
@ -317,6 +319,11 @@ public class KerasLayerUtils {
|
|||
layer = new KerasELU(layerConfig, enforceTrainingConfig);
|
||||
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
|
||||
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
|
||||
} else if (conf instanceof Keras2LayerConfiguration){
|
||||
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
|
||||
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){
|
||||
layer = new KerasTFOpLayer(layerConfig, enforceTrainingConfig);
|
||||
}
|
||||
}
|
||||
if (layer == null){
|
||||
Class<? extends KerasLayer> customConfig = customLayers.get(layerClassName);
|
||||
|
@ -402,6 +409,16 @@ public class KerasLayerUtils {
|
|||
public static String getLayerNameFromConfig(Map<String, Object> layerConfig,
|
||||
KerasLayerConfiguration conf)
|
||||
throws InvalidKerasConfigurationException {
|
||||
if(conf instanceof Keras2LayerConfiguration){
|
||||
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
|
||||
if (getClassNameFromConfig(layerConfig, conf).equals(((Keras2LayerConfiguration) conf).getTENSORFLOW_OP_LAYER())){
|
||||
if (!layerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
|
||||
throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
|
||||
+ " missing from layer config");
|
||||
return (String) layerConfig.get(conf.getLAYER_FIELD_NAME());
|
||||
}
|
||||
}
|
||||
|
||||
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||
if (!innerConfig.containsKey(conf.getLAYER_FIELD_NAME()))
|
||||
throw new InvalidKerasConfigurationException("Field " + conf.getLAYER_FIELD_NAME()
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.nn.modelimport.keras;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.Arrays;
|
||||
|
||||
public class TFKerasTests extends BaseDL4JTest{
|
||||
|
||||
@Test
|
||||
public void testModelWithTFOp1() throws Exception{
|
||||
File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
|
||||
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
|
||||
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
|
||||
Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testModelWithTFOp2() throws Exception{
|
||||
File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
|
||||
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
|
||||
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
|
||||
// dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed
|
||||
long[] expectedShape = new long[]{12 * 2, 5};
|
||||
Assert.assertArrayEquals(expectedShape, out.shape());
|
||||
}
|
||||
|
||||
}
|
|
@ -77,7 +77,11 @@
|
|||
<artifactId>nd4j-common</artifactId>
|
||||
<version>${nd4j.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
<version>${gson.version}</version>
|
||||
</dependency>
|
||||
<!-- ND4J Shaded Jackson Dependency -->
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
|
|
|
@ -62,6 +62,7 @@ public abstract class AbstractLayer<LayerConfT extends org.deeplearning4j.nn.con
|
|||
|
||||
public AbstractLayer(NeuralNetConfiguration conf, DataType dataType) {
|
||||
this.conf = conf;
|
||||
if (conf != null)
|
||||
cacheMode = conf.getCacheMode();
|
||||
this.dataType = dataType;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
|
||||
package org.nd4j;
|
||||
|
||||
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public interface TFGraphRunnerService{
|
||||
TFGraphRunnerService init(
|
||||
List<String> inputNames,
|
||||
List<String> outputNames,
|
||||
byte[] graphBytes,
|
||||
Map<String, INDArray> constants,
|
||||
Map<String, String> inputDataTypes
|
||||
);
|
||||
|
||||
Map<String,INDArray> run(Map<String,INDArray> inputs);
|
||||
}
|
|
@ -16,18 +16,16 @@
|
|||
|
||||
package org.nd4j.tensorflow.conversion.graphrunner;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Singular;
|
||||
import lombok.*;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.shade.protobuf.ByteString;
|
||||
import org.nd4j.shade.protobuf.InvalidProtocolBufferException;
|
||||
import org.nd4j.shade.protobuf.util.JsonFormat;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.tensorflow.conversion.TensorDataType;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
|
@ -56,6 +54,7 @@ import static org.bytedeco.tensorflow.global.tensorflow.*;
|
|||
* @author Adam Gibson
|
||||
*/
|
||||
@Slf4j
|
||||
@NoArgsConstructor
|
||||
public class GraphRunner implements Closeable {
|
||||
|
||||
private static boolean isTfWarmedUp = false;
|
||||
|
@ -103,6 +102,9 @@ public class GraphRunner implements Closeable {
|
|||
* @param inputDataTypes the expected input data types
|
||||
* @param outputDataTypes the expected output data types
|
||||
*/
|
||||
|
||||
|
||||
|
||||
@Builder
|
||||
public GraphRunner(List<String> inputNames,
|
||||
List<String> outputNames,
|
||||
|
@ -440,6 +442,7 @@ public class GraphRunner implements Closeable {
|
|||
* @return a map of the output names to the
|
||||
* ndarrays matching each output specified in the graph
|
||||
*/
|
||||
|
||||
public Map<String,INDArray> run(Map<String,INDArray> inputs) {
|
||||
if (!isTfWarmedUp && !isTfWarmingUp){
|
||||
isTfWarmingUp = true;
|
||||
|
@ -683,4 +686,7 @@ public class GraphRunner implements Closeable {
|
|||
|
||||
return builder1.build();
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
package org.nd4j.tensorflow.conversion.graphrunner;
|
||||
|
||||
import org.nd4j.TFGraphRunnerService;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.tensorflow.conversion.TensorDataType;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class GraphRunnerServiceProvider implements TFGraphRunnerService {
|
||||
|
||||
private GraphRunner graphRunner;
|
||||
Map<String, INDArray> inputs;
|
||||
|
||||
@Override
|
||||
public TFGraphRunnerService init(
|
||||
List<String> inputNames,
|
||||
List<String> outputNames,
|
||||
byte[] graphBytes,
|
||||
Map<String, INDArray> constants,
|
||||
Map<String, String> inputDataTypes){
|
||||
if (inputNames.size() != inputDataTypes.size()){
|
||||
throw new IllegalArgumentException("inputNames.size() != inputDataTypes.size()");
|
||||
}
|
||||
Map<String, TensorDataType> convertedDataTypes = new HashMap<>();
|
||||
for (int i = 0; i < inputNames.size(); i++){
|
||||
convertedDataTypes.put(inputNames.get(i), TensorDataType.fromProtoValue(inputDataTypes.get(inputNames.get(i))));
|
||||
}
|
||||
Map<String, INDArray> castConstants = new HashMap<>();
|
||||
for (Map.Entry<String, INDArray> e: constants.entrySet()) {
|
||||
DataType requiredDtype = TensorDataType.toNd4jType(TensorDataType.fromProtoValue(inputDataTypes.get(e.getKey())));
|
||||
castConstants.put(e.getKey(), e.getValue().castTo(requiredDtype));
|
||||
}
|
||||
this.inputs = castConstants;
|
||||
graphRunner = GraphRunner.builder().inputNames(inputNames)
|
||||
.outputNames(outputNames).graphBytes(graphBytes)
|
||||
.inputDataTypes(convertedDataTypes).build();
|
||||
return this;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, INDArray> run(Map<String, INDArray> inputs){
|
||||
if (graphRunner == null){
|
||||
throw new RuntimeException("GraphRunner not initialized.");
|
||||
}
|
||||
this.inputs.putAll(inputs);
|
||||
return graphRunner.run(this.inputs);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
################################################################################
|
||||
# Copyright (c) 2020 Konduit K.K..
|
||||
#
|
||||
# This program and the accompanying materials are made available under the
|
||||
# terms of the Apache License, Version 2.0 which is available at
|
||||
# https://www.apache.org/licenses/LICENSE-2.0.
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
################################################################################
|
||||
|
||||
org.nd4j.tensorflow.conversion.graphrunner.GraphRunnerServiceProvider
|
Loading…
Reference in New Issue