From 4cb87a94e859a066267fc5c365ed80b51e35486f Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Tue, 28 Apr 2020 14:31:09 +0400 Subject: [PATCH] tf.keras import test and fixes (#347) * merge conf * merge conf * tfkeras tests * parameterized tests * rename * cuda versions * jccp versions * 'updates' * updates * rnn+mlp passing * repeat * updates * tests * Update pom.xml * Update pom.xml * rem print * cnn1d model conversion fixed * cnn1d activate fixed * cnn1d output shape fix * cnn1d bprop fix * cnn1d stack fix * KerasModelEndToEndTest - Remove permutes for NWC and NHWC format tests Signed-off-by: Alex Black * Fixes and update test - input shapes (NCHW -> NHWC input) Signed-off-by: Alex Black * Ignore for known bad tests Signed-off-by: Alex Black * Multiple fixes - MergeVertex, CNN1D layers, etc Signed-off-by: Alex Black * Fix issue with RNN/FF preprocessors, time distributed etc with NWC format Signed-off-by: Alex Black * LSTM NWC dropout fix Signed-off-by: Alex Black * Add sequence embedding layer NWC support (configurable output format) Signed-off-by: Alex Black * Fix expected shape in a couple of tests - NWC expected Signed-off-by: Alex Black * Fix EmbeddingSequenceLayer backprop for NWC output case + add gradient checks Signed-off-by: Alex Black * CnnToFeedForwardPreprocessor: align with Keras/TF; fix Keras reshape/flatten Signed-off-by: Alex Black * Update ConvDataFormatTests to match new reshape behaviour Signed-off-by: Alex Black * Switch hard-coded path to ResourceUtils.listClassPathFiles for TestTFKerasModelImport Signed-off-by: Alex Black * TestUtils fix Signed-off-by: Alex Black * Fixes Signed-off-by: Alex Black * Fix JSON serde issue with data formats Signed-off-by: Alex Black * Fix for input dtype inference; fix 2 tests Signed-off-by: Alex Black * Test fixes Signed-off-by: Alex Black * #8891 Ignore for TestVertxUIMultiSession until fixed Signed-off-by: Alex Black * Restore but deprecate TensorFlowCnnToFeedForwardPreProcessor for older zoo models Signed-off-by: Alex Black * Ignore for deprecated preprocessor in DTypeTests Signed-off-by: Alex Black * Remove debug printlns Signed-off-by: Alex Black Co-authored-by: Alex Black --- .../org/datavec/python/PythonExecutioner.java | 14 ++ .../org/datavec/python/PythonProcess.java | 132 +++++++++++++ .../java/org/datavec/python/keras/Model.java | 144 +++++++++++++++ .../java/org/deeplearning4j/TestUtils.java | 18 +- .../gradientcheck/GradientCheckTests.java | 131 ++++++------- .../GradientCheckTestsComputationGraph.java | 174 +++++++++--------- .../deeplearning4j/nn/dtypes/DTypeTests.java | 14 +- .../nn/graph/graphnodes/TestGraphNodes.java | 6 +- .../convolution/ConvDataFormatTests.java | 103 ++++++++++- .../nn/layers/recurrent/TestRnnLayers.java | 8 +- .../nn/layers/recurrent/TestSimpleRnn.java | 6 +- .../layers/recurrent/TestTimeDistributed.java | 75 +++++++- .../convolution/ConvDataFormatTests.java | 48 ++++- .../convolution/TestConvolution.java | 2 +- .../deeplearning4j-modelimport/pom.xml | 6 + .../modelimport/keras/layers/KerasInput.java | 21 ++- .../modelimport/keras/layers/TFOpLayer.java | 4 + .../keras/layers/TFOpLayerImpl.java | 8 - .../convolutional/KerasConvolution1D.java | 3 +- .../convolutional/KerasConvolution2D.java | 6 +- .../convolutional/KerasConvolutionUtils.java | 15 +- .../keras/layers/core/KerasFlatten.java | 9 +- .../keras/layers/core/KerasRepeatVector.java | 2 + .../keras/layers/core/KerasReshape.java | 15 +- .../layers/embeddings/KerasEmbedding.java | 2 + .../keras/layers/recurrent/KerasLSTM.java | 2 +-
.../layers/recurrent/KerasSimpleRnn.java | 2 +- .../layers/wrappers/KerasBidirectional.java | 2 +- .../preprocessors/ReshapePreprocessor.java | 43 +++-- ...ensorFlowCnnToFeedForwardPreProcessor.java | 15 +- .../keras/utils/KerasLayerUtils.java | 3 + .../nn/modelimport/keras/TFKerasTests.java | 50 ----- .../keras/TestTFKerasModelImport.java | 147 +++++++++++++++ .../keras/configurations/JsonTest.java | 4 +- .../Keras2ModelConfigurationTest.java | 3 +- .../keras/e2e/KerasModelEndToEndTest.java | 81 +++----- .../weights/KerasWeightSettingTests.java | 19 +- .../deeplearning4j/nn/conf/CNN2DFormat.java | 2 +- .../deeplearning4j/nn/conf/DataFormat.java | 26 +++ .../nn/conf/MultiLayerConfiguration.java | 2 +- .../org/deeplearning4j/nn/conf/RNNFormat.java | 2 +- .../nn/conf/graph/MergeVertex.java | 17 +- .../nn/conf/inputs/InputType.java | 17 +- .../nn/conf/layers/BaseRecurrentLayer.java | 4 +- .../nn/conf/layers/Convolution1DLayer.java | 17 +- .../conf/layers/EmbeddingSequenceLayer.java | 12 +- .../nn/conf/layers/FeedForwardLayer.java | 9 +- .../nn/conf/layers/InputTypeUtil.java | 8 +- .../nn/conf/layers/RnnOutputLayer.java | 4 +- .../nn/conf/layers/Subsampling1DLayer.java | 2 +- .../nn/conf/layers/misc/RepeatVector.java | 17 +- .../CnnToFeedForwardPreProcessor.java | 40 ++-- .../FeedForwardToRnnPreProcessor.java | 3 +- .../RnnToFeedForwardPreProcessor.java | 5 +- .../serde/format/DataFormatDeserializer.java | 52 ++++++ .../serde/format/DataFormatSerializer.java | 37 ++++ .../nn/graph/vertex/impl/MergeVertex.java | 42 +++-- .../nn/layers/BaseOutputLayer.java | 3 +- .../nn/layers/RepeatVector.java | 25 ++- .../convolution/Convolution1DLayer.java | 29 ++- .../layers/convolution/ConvolutionLayer.java | 3 +- .../embedding/EmbeddingSequenceLayer.java | 29 ++- .../nn/layers/recurrent/LSTM.java | 16 +- .../layers/recurrent/LastTimeStepLayer.java | 7 +- .../nn/layers/recurrent/RnnOutputLayer.java | 3 +- .../params/ConvolutionParamInitializer.java | 1 + .../ui/TestVertxUIMultiSession.java | 3 +- 67 files changed, 1336 insertions(+), 438 deletions(-) create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/keras/Model.java delete mode 100644 deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java create mode 100644 deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TestTFKerasModelImport.java create mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/DataFormat.java create mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java create mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index 3d08d3141..d0ed42e36 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -20,6 +20,7 @@ package org.datavec.python; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; +import org.bytedeco.cpython.global.python; import org.bytedeco.numpy.global.numpy; import org.nd4j.linalg.api.concurrency.AffinityManager; import 
org.nd4j.linalg.api.ndarray.INDArray; @@ -343,6 +344,19 @@ public class PythonExecutioner { if (path == null) { log.info("Setting python default path"); File[] packages = numpy.cachePackages(); + + //// TODO: fix in javacpp + File sitePackagesWindows = new File(python.cachePackage(), "site-packages"); + File[] packages2 = new File[packages.length + 1]; + for (int i = 0;i < packages.length; i++){ + //System.out.println(packages[i].getAbsolutePath()); + packages2[i] = packages[i]; + } + packages2[packages.length] = sitePackagesWindows; + //System.out.println(sitePackagesWindows.getAbsolutePath()); + packages = packages2; + ////////// + Py_SetPath(packages); } else { log.info("Setting python path " + path); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java new file mode 100644 index 000000000..a8ee56510 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java @@ -0,0 +1,132 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.IOUtils; +import org.bytedeco.javacpp.Loader; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +@Slf4j +public class PythonProcess { + private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); + public static String runAndReturn(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + log.info("Executing command: " + Arrays.toString(allArgs)); + ProcessBuilder pb = new ProcessBuilder(allArgs); + Process process = pb.start(); + String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8); + process.waitFor(); + return out; + + } + + public static void run(String... 
arguments) throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + log.info("Executing command: " + Arrays.toString(allArgs)); + ProcessBuilder pb = new ProcessBuilder(allArgs); + pb.inheritIO().start().waitFor(); + } + public static void pipInstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "install", packageName); + }catch(Exception e){ + throw new PythonException("Error installing package " + packageName, e); + } + + } + + public static void pipInstall(String packageName, String version) throws PythonException{ + pipInstall(packageName + "==" + version); + } + + public static void pipUninstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "uninstall", "-y", packageName); + }catch(Exception e){ + throw new PythonException("Error uninstalling package " + packageName, e); + } + + } + public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{ + if (!gitRepoUrl.contains("://")){ + gitRepoUrl = "git://" + gitRepoUrl; + } + try{ + run("-m", "pip", "install", "git+" + gitRepoUrl); + }catch(Exception e){ + throw new PythonException("Error installing package from " + gitRepoUrl, e); + } + + } + + public static String getPackageVersion(String packageName) throws PythonException{ + String out; + try{ + out = runAndReturn("-m", "pip", "show", packageName); + } catch (Exception e){ + throw new PythonException("Error finding version for package " + packageName, e); + } + + if (!out.contains("Version: ")){ + throw new PythonException("Can't find package " + packageName); + } + String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0]; + return pkgVersion; + } + + public static boolean isPackageInstalled(String packageName) throws PythonException{ + try{ + String out = runAndReturn("-m", "pip", "show", packageName); + return !out.isEmpty(); + }catch (Exception e){ + throw new PythonException("Error checking if package is installed: " + packageName, e); + } + + } + + public static void pipInstallFromRequirementsTxt(String path) throws PythonException{ + try{ + run("-m", "pip", "install", "-r", path); + }catch (Exception e){ + throw new PythonException("Error installing packages from " + path, e); + } + } + + public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException{ + + try{ + run(path, inplace ? "develop" : "install"); + }catch (Exception e){ + throw new PythonException("Error installing package from " + path, e); + } + + } + +} \ No newline at end of file
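[Reviewer note] A minimal usage sketch for the new PythonProcess helper, assuming the bundled CPython is resolvable via JavaCPP's Loader as in the class above (hypothetical driver code, not part of this patch; the pinned version string is illustrative only):

    import org.datavec.python.PythonProcess;

    public class PythonProcessExample {
        public static void main(String[] args) throws Exception {
            // Install a pinned package into the embedded interpreter's environment
            if (!PythonProcess.isPackageInstalled("numpy")) {
                PythonProcess.pipInstall("numpy", "1.18.2"); // runs: python -m pip install numpy==1.18.2
            }
            // Version is parsed from the "Version: " line of `pip show` output
            System.out.println("numpy: " + PythonProcess.getPackageVersion("numpy"));
        }
    }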
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/keras/Model.java b/datavec/datavec-python/src/main/java/org/datavec/python/keras/Model.java new file mode 100644 index 000000000..d8a9b0651 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/keras/Model.java @@ -0,0 +1,144 @@ +package org.datavec.python.keras; + +import org.datavec.python.Python; +import org.datavec.python.PythonException; +import org.datavec.python.PythonObject; +import org.datavec.python.PythonProcess; +import org.nd4j.linalg.api.ndarray.INDArray; + +public class Model { + + private PythonObject pyModel; + + private static PythonObject installAndImportTF() throws PythonException{ + if (!PythonProcess.isPackageInstalled("tensorflow")){ + PythonProcess.pipInstall("tensorflow"); + } + return Python.importModule("tensorflow"); + } + private static PythonObject getKerasModule() throws PythonException{ + PythonObject tf = installAndImportTF(); + PythonObject keras = tf.attr("keras"); + tf.del(); + return keras; + } + + private static PythonObject loadModel(String s) throws PythonException{ + PythonObject models = getKerasModule().attr("models"); + PythonObject loadModelF = models.attr("load_model"); + PythonObject model = loadModelF.call(s); + models.del(); + loadModelF.del(); + return model; + } + + public Model(String path) throws PythonException{ + pyModel = loadModel(path); + } + + public INDArray[] predict(INDArray... inputs) throws PythonException{ + PythonObject predictF = pyModel.attr("predict"); + PythonObject inputList = new PythonObject(inputs); + PythonObject pyOut = predictF.call(inputList); + INDArray[] out; + if (Python.isinstance(pyOut, Python.listType())){ + out = new INDArray[Python.len(pyOut).toInt()]; + for(int i = 0; i < out.length; i++){ + out[i] = pyOut.get(i).toNumpy().getNd4jArray(); + } + } + else{ + out = new INDArray[]{ pyOut.toNumpy().getNd4jArray()}; + } + + predictF.del(); + inputList.del(); + pyOut.del(); + return out; + } + + public int numInputs(){ + PythonObject inputs = pyModel.attr("inputs"); + PythonObject pyNumInputs = Python.len(inputs); + int ret = pyNumInputs.toInt(); + inputs.del(); + pyNumInputs.del(); + return ret; + } + public int numOutputs(){ + PythonObject outputs = pyModel.attr("outputs"); + PythonObject pyNumOutputs = Python.len(outputs); + int ret = pyNumOutputs.toInt(); + outputs.del(); + pyNumOutputs.del(); + return ret; + } + + public long[][] inputShapes(){ + long[][] ret = new long[numInputs()][]; + for (int i = 0; i < ret.length; i++){ + ret[i] = inputShapeAt(i); + } + return ret; + } + + public long[][] outputShapes(){ + long[][] ret = new long[numOutputs()][]; + for (int i = 0; i < ret.length; i++){ + ret[i] = outputShapeAt(i); + } + return ret; + } + + public long[] inputShapeAt(int input){ + PythonObject inputs = pyModel.attr("inputs"); + PythonObject tensor = inputs.get(input); + PythonObject tensorShape = tensor.attr("shape"); + PythonObject shapeList = Python.list(tensorShape); + PythonObject pyNdim = Python.len(shapeList); + int ndim = pyNdim.toInt(); + long[] shape = new long[ndim]; + for(int i = 0; i < shape.length; i++){ + PythonObject pyDim = shapeList.get(i); + if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){ + shape[i] = -1; + } + else{ + shape[i] = pyDim.toLong(); + } + } + pyNdim.del(); + shapeList.del(); + tensorShape.del(); + tensor.del(); + inputs.del(); + return shape; + } + + public long[] outputShapeAt(int output){ + PythonObject outputs = pyModel.attr("outputs"); + PythonObject tensor = outputs.get(output); + PythonObject tensorShape = tensor.attr("shape"); + PythonObject shapeList = Python.list(tensorShape); + PythonObject pyNdim = Python.len(shapeList); + int ndim = pyNdim.toInt(); + long[] shape = new long[ndim]; + for(int i = 0; i < shape.length; i++){ + PythonObject pyDim = shapeList.get(i); + if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){ + shape[i] = -1; + } + else{ + shape[i] = pyDim.toLong(); + } + } + pyNdim.del(); + shapeList.del(); + tensorShape.del(); + tensor.del(); + outputs.del(); + return shape; + } +}
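[Reviewer note] Round-trip sketch for the new wrapper, assuming TensorFlow can be pip-installed into the embedded interpreter on first use (model path and shapes here are hypothetical):

    import org.datavec.python.keras.Model;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    Model kerasModel = new Model("/tmp/my_model.h5"); // load_model via tf.keras.models
    INDArray input = Nd4j.rand(1, 10);                // must match kerasModel.inputShapeAt(0); dynamic axes are reported as -1
    INDArray[] outputs = kerasModel.predict(input);   // multi-output models come back as one INDArray per output
    System.out.println(outputs[0].shapeInfoToString());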
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java index d54693f73..dac854d9d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -153,11 +154,22 @@ public class TestUtils { return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed)); } - public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){ - INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f'); + public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) { + return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng); + } + + public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){ + boolean ncw = format == RNNFormat.NCW; + long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize}; + char order = ncw ? 'f' : 'c'; + INDArray out = Nd4j.create(DataType.FLOAT, shape, order); for( int i=0; i feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { + return null; + } + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index 5b42d95bc..83b312308 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -44,6 +44,7 @@ import org.nd4j.linalg.primitives.Pair; import java.util.Arrays; import java.util.List; +import java.util.Random; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; @@ -217,7 +218,7 @@ public class TestRnnLayers extends BaseDL4JTest { NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder() .list() - .layer(new SimpleRnn.Builder().nIn(5).nOut(5).build()); + .layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build()); switch (i){ case 0: @@ -235,10 +236,7 @@ public class TestRnnLayers extends BaseDL4JTest { net.init(); INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5); - INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10); - if (rnnDataFormat == RNNFormat.NWC){ - l = l.permute(0, 2, 1); - } + INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new Random(12345)); try{ net.fit(in,l); } catch (Throwable t){ diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java index 7d61316d5..639d3fafd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java @@ -61,15 +61,13 @@ public class TestSimpleRnn extends BaseDL4JTest { int tsLength = 7; INDArray in;
if (rnnDataFormat == RNNFormat.NCW){ - in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength}); + in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength); } else{ - in = Nd4j.rand(DataType.FLOAT, new int[]{m, tsLength, nIn}); + in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn); } -// in.get(all(), all(), interval(1,tsLength)).assign(0); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .updater(new NoOp()) .weightInit(WeightInit.XAVIER) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java index f56261cc5..0d38d699c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java @@ -7,10 +7,14 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; +import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Test; import org.junit.runner.RunWith; @@ -106,4 +110,73 @@ public class TestTimeDistributed extends BaseDL4JTest { } } } + + + @Test + public void testTimeDistributedDense(){ + + for( int rnnType=0; rnnType<3; rnnType++ ) { + for( int ffType=0; ffType<3; ffType++ ) { + + Layer l0, l2; + switch (rnnType) { + case 0: + l0 = new LSTM.Builder().nOut(5).build(); + l2 = new LSTM.Builder().nOut(5).build(); + break; + case 1: + l0 = new SimpleRnn.Builder().nOut(5).build(); + l2 = new SimpleRnn.Builder().nOut(5).build(); + break; + case 2: + l0 = new Bidirectional(new LSTM.Builder().nOut(5).build()); + l2 = new Bidirectional(new LSTM.Builder().nOut(5).build()); + break; + default: + throw new RuntimeException("Not implemented: " + rnnType); + } + + Layer l1; + switch (ffType){ + case 0: + l1 = new DenseLayer.Builder().nOut(5).build(); + break; + case 1: + l1 = new VariationalAutoencoder.Builder().nOut(5).encoderLayerSizes(5).decoderLayerSizes(5).build(); + break; + case 2: + l1 = new AutoEncoder.Builder().nOut(5).build(); + break; + default: + throw new RuntimeException("Not implemented: " + ffType); + } + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .activation(Activation.TANH) + .list() + .layer(l0) + .layer(l1) + .layer(l2) + .setInputType(InputType.recurrent(5, 9, rnnDataFormat)) + .build(); + + BaseRecurrentLayer l0a; + BaseRecurrentLayer l2a; + if (rnnType < 2) { + l0a = (BaseRecurrentLayer) l0; + l2a = (BaseRecurrentLayer) l2; + } else { + l0a = (BaseRecurrentLayer) ((Bidirectional) l0).getFwd(); + l2a = (BaseRecurrentLayer) ((Bidirectional) l2).getFwd(); + } + assertEquals(rnnDataFormat, l0a.getRnnDataFormat()); + assertEquals(rnnDataFormat, l2a.getRnnDataFormat()); + + MultiLayerNetwork 
net = new MultiLayerNetwork(conf); + net.init(); + INDArray in = Nd4j.rand(DataType.FLOAT, rnnDataFormat == RNNFormat.NCW ? new long[]{2, 5, 9} : new long[]{2, 9, 5} ); + net.output(in); + } + } + } } diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java index c56994441..8210903ef 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java @@ -15,21 +15,24 @@ ******************************************************************************/ package org.deeplearning4j.convolution; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; +import lombok.*; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.CuDNNTestUtils; import org.deeplearning4j.TestUtils; +import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; +import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -816,6 +819,12 @@ public class ConvDataFormatTests extends BaseDL4JTest { .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build()) .setInputType(inputType != null ? 
inputType : InputType.convolutional(12, 12, 3, format)); + if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){ + //Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened + //DL4J's flattening behaviour matches Keras (hence TF) for import compatibility + builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor())); + } + MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); net.init(); return net; @@ -964,4 +973,35 @@ } return differs; } + + //Converts NHWC to NCHW activations + @EqualsAndHashCode + private static class NHWCToNCHWPreprocessor implements InputPreProcessor { + + @Override + public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2)); + } + + @Override + public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1)); + } + + @Override + public InputPreProcessor clone() { + return this; + } + + @Override + public InputType getOutputType(InputType inputType) { + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; + return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW); + } + + @Override + public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { + return null; + } + } } diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java index 67a2958b7..72ab2c043 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java @@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest { ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false); model = model.convertDataType(DataType.DOUBLE); - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize}); + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC CuDNNTestUtils.assertHelpersPresent(model.getLayers()); Map<String, INDArray> withCudnn = model.feedForward(in, false); diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml index 6d71c394e..3dcdcc720 100644 --- a/deeplearning4j/deeplearning4j-modelimport/pom.xml +++ b/deeplearning4j/deeplearning4j-modelimport/pom.xml @@ -113,6 +113,12 @@ <scope>test</scope> </dependency> + <dependency> + <groupId>org.datavec</groupId> + <artifactId>datavec-python</artifactId> + <version>${datavec.version}</version> + <scope>test</scope> + </dependency>
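[Reviewer note] The NHWCToNCHWPreprocessor above relies on permute(0,3,1,2) for the forward pass and permute(0,2,3,1) for backprop; these are inverse permutations of each other, which a standalone snippet can sanity-check (illustrative, not part of the patch):

    INDArray nhwc = Nd4j.rand(DataType.FLOAT, 2, 12, 12, 3); // [minibatch, height, width, channels]
    INDArray nchw = nhwc.permute(0, 3, 1, 2);                // -> [2, 3, 12, 12]
    INDArray roundTrip = nchw.permute(0, 2, 3, 1);           // inverse -> [2, 12, 12, 3]
    System.out.println(java.util.Arrays.equals(nhwc.shape(), roundTrip.shape())); // true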
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java index 785e480d1..d83d27011 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java @@ -19,6 +19,8 @@ package org.deeplearning4j.nn.modelimport.keras.layers; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -121,27 +123,27 @@ public class KerasInput extends KerasLayer { InputType myInputType; switch (this.inputShape.length) { case 1: - myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]); + myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null); break; case 2: if(this.dimOrder != null) { switch (this.dimOrder) { case TENSORFLOW: //NWC == channels_last - myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); + myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC); break; case THEANO: //NCW == channels_first - myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW); break; case NONE: //Assume RNN in [mb, seqLen, size] format - myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC); break; default: throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder); } } else { //Assume RNN in [mb, seqLen, size] format - myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC); } break; @@ -150,17 +152,17 @@ case 3: switch (this.dimOrder) { case TENSORFLOW: /* TensorFlow convolutional input: # rows, # cols, # channels */ myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1], - this.inputShape[2]); + this.inputShape[2], CNN2DFormat.NHWC); break; case THEANO: /* Theano convolutional input: # channels, # rows, # cols */ myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2], - this.inputShape[0]); + this.inputShape[0], CNN2DFormat.NCHW); break; default: this.dimOrder = DimOrder.THEANO; myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2], - this.inputShape[0]); + this.inputShape[0], CNN2DFormat.NCHW); log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " + "versions may come without specified backend, in which case we assume the model was " + "built with theano.");
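[Reviewer note] Net effect of the KerasInput change, with illustrative shapes: a TF-backend recurrent input declared as input_shape=(timesteps, features) is now kept in its native [minibatch, time, size] layout instead of being reinterpreted as NCW.

    // Keras: Input(shape=(100, 8)), channels_last / TENSORFLOW dim ordering (shapes hypothetical)
    // Before: new InputType.InputTypeRecurrent(100, 8)                 // axes misread, data permuted to NCW
    // After:  new InputType.InputTypeRecurrent(8, 100, RNNFormat.NWC)  // size 8, sequence length 100, NWC preserved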
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java index ecf64e8c0..5a1f0e8dd 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java @@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; @@ -65,6 +66,9 @@ public class TFOpLayer extends Layer { long[] shape = inputType.getShape(true); TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null); long[] outputShape = tempLayer.getOutputShape(shape); + if (outputShape.length == 3){ + return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC); + } return InputType.inferInputType(Nd4j.create(outputShape)); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java index d7b0b3b56..ebc1ca5f0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java @@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer { } private INDArray runGraph(INDArray input){ - if (input.rank() == 3){ - // TODO make this a preprocessor - input = input.permute(0, 2, 1); - } Map<String, INDArray> inputMap = new HashMap<>(); inputMap.put(inputNames.get(0), input); INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0]; - if (out.rank() == 3){ - out = out.permute(0, 2, 1); // TODO post-processing?
- } - return out; } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java index a7d54c1e4..49e862de7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java @@ -95,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution { IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) @@ -104,7 +103,7 @@ public class KerasConvolution1D extends KerasConvolution { .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) .hasBias(hasBias) - .stride(getStrideFromConfig(layerConfig, 1, conf)[0]); + .stride(getStrideFromConfig(layerConfig, 1, conf)[0]).rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion); if (hasBias) builder.biasInit(0.0); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java index e9c74e78c..51d481e5d 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -101,7 +102,8 @@ public class KerasConvolution2D extends KerasConvolution { .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) .hasBias(hasBias) - .stride(getStrideFromConfig(layerConfig, 2, conf)); + .stride(getStrideFromConfig(layerConfig, 2, conf)) + .dataFormat((dimOrder == DimOrder.TENSORFLOW) ? CNN2DFormat.NHWC : CNN2DFormat.NCHW); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion); if (hasBias) builder.biasInit(0.0);
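[Reviewer note] A hedged, self-contained equivalent of what the 2D builder now emits for a channels_last layer (nOut/kernel/stride values are illustrative only):

    import org.deeplearning4j.nn.conf.CNN2DFormat;
    import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;

    ConvolutionLayer conv = new ConvolutionLayer.Builder()
            .nOut(32)
            .kernelSize(3, 3)
            .stride(1, 1)
            .dataFormat(CNN2DFormat.NHWC) // dimOrder == TENSORFLOW -> NHWC; otherwise NCHW
            .build();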
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java index 0968260b7..d60e5f6e8 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java @@ -360,8 +360,19 @@ public class KerasConvolutionUtils { } } else if (dimension == 1) { - int paddingInt = (int) innerConfig.get(layerField); - padding = new int[]{paddingInt, paddingInt}; + Object paddingObj = innerConfig.get(layerField); + if (paddingObj instanceof List){ + List<Integer> paddingList = (List<Integer>) paddingObj; + padding = new int[]{ + paddingList.get(0), + paddingList.get(1) + }; + } + else{ + int paddingInt = (int) innerConfig.get(layerField); + padding = new int[]{paddingInt, paddingInt}; + } + } else { throw new UnsupportedKerasConfigurationException( "Keras padding layer not supported"); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java index 196f9d3d9..a807d3357 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional; @@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; -import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor; import java.util.Map; @@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer { switch (this.getDimOrder()) { case NONE: case THEANO: - preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels()); + preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW); break; case TENSORFLOW: - preprocessor = new
TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), - it.getChannels()); + preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC); break; default: throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder()); @@ -111,7 +110,7 @@ public class KerasFlatten extends KerasLayer { // to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten). InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; val inputShape = new long[]{it.getSize()}; - preprocessor = new ReshapePreprocessor(inputShape, inputShape, false); + preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null); } return preprocessor; } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java index 41254e221..15f3011c4 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer { super(layerConfig, enforceTrainingConfig); this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf)) + .dataFormat(RNNFormat.NWC) .name(this.layerName).build(); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java index e5f1375d1..bb4dc5ecf 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer { } else { targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]}; } - preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW); } else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2) - if (inputShape[0] != targetShape[0]) - targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]}; - preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC); } } else 
if (inputType[0] instanceof InputType.InputTypeConvolutional3D) { @@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer { } else { targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] }; } - preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null); } else { if (inputShape[0] != targetShape[0]) targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] }; - preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null); } } else if (inputType[0] instanceof InputType.InputTypeRecurrent) { InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0]; val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()}; - preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null); } else if (inputType[0] instanceof InputType.InputTypeFeedForward) { InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; val inputShape = new long[]{it.getSize()}; if (targetShape.length == 3) { targetShape = targetShapeForDimOrder(inputShape, targetShape); } - preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null); } return preprocessor; } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java index 1ee13c0b0..f5a29d4f5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java @@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer { .biasInit(0.0) .l1(this.weightL1Regularization) .l2(this.weightL2Regularization) + .outputDataFormat(RNNFormat.NWC) .hasBias(false); if (embeddingConstraint != null) builder.constrainWeights(embeddingConstraint); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index 0888b3376..3ab596cd5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -186,7 +186,7 @@ public class KerasLSTM extends KerasLayer { .weightInitRecurrent(recurrentInit) .biasInit(0.0) // TODO: this is 
incorrect .l1(this.weightL1Regularization) - .l2(this.weightL2Regularization); + .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC); Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); if(nIn != null) builder.setNIn(nIn); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index 67fe611e1..3e44bd1f9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -158,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer { .weightInitRecurrent(recurrentInit) .biasInit(0.0) .l1(this.weightL1Regularization) - .l2(this.weightL2Regularization); + .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC); Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); if(nIn != null) builder.setNIn(nIn); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java index c8cc4fc20..faa271987 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java @@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer { break; case "SimpleRNN": kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers); - SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer(); + Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer(); this.layer = new Bidirectional(mode, rnnLayer); layer.setLayerName(layerName); break; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java index afc9392a5..e6819a3bf 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java @@ -21,6 +21,9 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.DataFormat; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; @@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { private final long[] inputShape; private final long[] targetShape; private boolean hasMiniBatchDimension; - - /** - * @deprecated Use constructor {@link 
#ReshapePreprocessor(long[], long[], boolean)} - */ - @Deprecated - public ReshapePreprocessor(long[] inputShape, long[] targetShape) { - this(inputShape, targetShape, false); - } + private DataFormat format; /** * @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension * @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...] */ + public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) { + this(inputShape, targetShape, hasMiniBatchDimension, null); + } + + /** + * @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension + * @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension + * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...] + * @param dataFormat May be null. If non-null: the format (CNN2DFormat or RNNFormat) used when inferring the output type in getOutputType + */ public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape, - @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) { + @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension, + @JsonProperty("dataFormat") DataFormat dataFormat) { this.inputShape = inputShape; this.targetShape = targetShape; this.hasMiniBatchDimension = hasMiniBatchDimension; + this.format = dataFormat; } private long[] getShape(long[] originalShape, long minibatch) { @@ -140,13 +148,26 @@ ret = InputType.feedForward(shape[1]); break; case 3: + RNNFormat format = RNNFormat.NCW; + if(this.format instanceof RNNFormat) + format = (RNNFormat)this.format; + + ret = InputType.recurrent(shape[2], shape[1], format); break; case 4: if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) { ret = InputType.convolutional(shape[1], shape[2], shape[3]); } else { + + CNN2DFormat cnnFormat = CNN2DFormat.NCHW; + if (this.format instanceof CNN2DFormat) + cnnFormat = (CNN2DFormat) this.format; + + if (cnnFormat == CNN2DFormat.NCHW) { + ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat); + } else { + ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat); + } } break; default:
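[Reviewer note] Example of the new four-argument constructor, mirroring how KerasFlatten/KerasReshape now pass the format through (shapes are illustrative):

    // Reshape a channels_last 12x12x3 activation to a length-432 vector; the NHWC hint
    // lets getOutputType() report convolutional output types with the correct axis order.
    ReshapePreprocessor proc = new ReshapePreprocessor(
            new long[]{12, 12, 3},  // input shape, no leading minibatch dimension...
            new long[]{432},        // target shape
            false,                  // ...because hasMiniBatchDimension == false
            CNN2DFormat.NHWC);      // may be null for format-agnostic reshapes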
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java index db7d2e990..2e1ac2e51 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java @@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator; import org.nd4j.shade.jackson.annotation.JsonProperty; /** - * Specialized CnnToFeedForwardInputPreProcessor for use with - * Convolutional layers imported from Keras using the TensorFlow - * backend. - * - * @author dave@skymind.io + * @deprecated Exists only for backward compatibility of older pretrained models. Should not be used. + * Use {@link CnnToFeedForwardPreProcessor} for all new models instead. */ -@Slf4j +@Slf4j @Deprecated public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor { - @JsonCreator + @JsonCreator @Deprecated public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight, @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) { super(inputHeight, inputWidth, numChannels); } + @Deprecated public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) { super(inputHeight, inputWidth); } + @Deprecated public TensorFlowCnnToFeedForwardPreProcessor() { super(); } @@ -81,4 +80,4 @@ public class TensorFlowCnnToFeedForwardPreP public TensorFlowCnnToFeedForwardPreProcessor clone() { return (TensorFlowCnnToFeedForwardPreProcessor) super.clone(); } -} +} \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java index 3f69cb7d4..a0578a4df 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*; import org.deeplearning4j.nn.modelimport.keras.layers.core.*; import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding; +import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise; @@ -319,6 +320,8 @@ public class KerasLayerUtils { layer = new KerasELU(layerConfig, enforceTrainingConfig); } else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){ layer = new KerasSoftmax(layerConfig, enforceTrainingConfig); + } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){ + layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig); } else if (conf instanceof Keras2LayerConfiguration){ Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf; if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){ diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java deleted file mode 100644 index cb74b1ed1..000000000 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java +++ /dev/null @@ -1,50 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2020 Konduit K.K.
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.modelimport.keras; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.junit.Assert; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.resources.Resources; - -import java.io.File; -import java.util.Arrays; - -public class TFKerasTests extends BaseDL4JTest{ - - @Test - public void testModelWithTFOp1() throws Exception{ - File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5"); - ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath()); - INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3)); - Assert.assertArrayEquals(new long[]{12, 3}, out.shape()); - } - - @Test - public void testModelWithTFOp2() throws Exception{ - File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5"); - ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath()); - INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3)); - // dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed - long[] expectedShape = new long[]{12 * 2, 5}; - Assert.assertArrayEquals(expectedShape, out.shape()); - } - -} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TestTFKerasModelImport.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TestTFKerasModelImport.java new file mode 100644 index 000000000..455d7ca56 --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TestTFKerasModelImport.java @@ -0,0 +1,147 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras; + +import org.apache.commons.io.FileUtils; +import org.datavec.python.keras.Model; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.common.tests.ResourceUtils; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.concurrency.AffinityManager; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.resources.Resources; + +import java.io.File; +import java.util.List; + + +@RunWith(Parameterized.class) +public class TestTFKerasModelImport extends BaseDL4JTest{ + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + private String modelFile; + + @Override + public long getTimeoutMilliseconds(){ + return 300000; + } // installing TF will take a while + + + @Parameterized.Parameters(name = "file={0}") + public static Object[] params() throws Exception { + List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false); + return paths.toArray(new String[0]); + } + + public TestTFKerasModelImport(String modelFile){ + this.modelFile = modelFile; + } + + @Test + public void testModelImport() throws Exception{ + testModelImportWithData(modelFile); + } + + private void testModelImportWithData(String path) throws Exception{ + System.out.println(path); + // TODO multi input/output + INDArray inputArray; + INDArray expectedOutputArray; + File f = Resources.asFile(path); //May be in a JAR that HDF5 can't read from + File modelFile = new File(testDir.getRoot(), f.getName()); + FileUtils.copyFile(f, modelFile); + + synchronized (Hdf5Archive.LOCK_OBJECT){ + Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath()); + List<String> rootGroups = hdf5Archive.getGroups(); + if (rootGroups.contains("data")){ + String inputName = hdf5Archive.readAttributeAsString("input_names", "data"); + String outputName = hdf5Archive.readAttributeAsString("output_names", "data"); + inputArray = hdf5Archive.readDataSet(inputName, "data"); + expectedOutputArray = hdf5Archive.readDataSet(outputName, "data"); + } + else{ + hdf5Archive.close(); + return; + } + hdf5Archive.close(); + } + INDArray outputArray; + + ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path); + outputArray = dl4jModel.outputSingle(inputArray); + + expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT); + outputArray = outputArray.castTo(DataType.FLOAT); + if (path.contains("misc_")){ + //shape relaxation + expectedOutputArray = expectedOutputArray.reshape(-1); + outputArray = outputArray.reshape(-1); + } + + System.out.println(outputArray.toString()); + System.out.println(expectedOutputArray.toString()); + Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape()); + Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3)); + } + + private void testModelImportWithKeras(String path) throws Exception{ + Model kerasModel = new Model(path); + ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path); + Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays()); + Assert.assertEquals(kerasModel.numOutputs(), dl4jModel.getNumOutputArrays()); + INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()]; + INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()]; + + for (int i = 0; i < kerasInputArrays.length; i++) { + long[] shape = kerasModel.inputShapeAt(i); + for (int j = 0; j < shape.length; j++) { + if (shape[j] < 0) { + shape[j] = 1; + } + } + + kerasInputArrays[i] = Nd4j.rand(shape); + dl4jInputArrays[i] = kerasInputArrays[i]; //Same random inputs for both models + } + + INDArray[] kerasOut = kerasModel.predict(kerasInputArrays); + INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays); + + Assert.assertEquals(kerasOut.length, dl4jOut.length); + + for (int i = 0; i < kerasOut.length; i++){ + INDArray kerasOutArr = kerasOut[i]; + kerasOutArr = kerasOutArr.reshape(1, -1); //bit of relaxation on shape + kerasOutArr = kerasOutArr.castTo(DataType.DOUBLE); + Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST); + INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1); + System.out.println(kerasOutArr.shapeInfoToString()); + System.out.println(dl4jOutArr.shapeInfoToString()); + Assert.assertEquals(kerasOutArr, dl4jOutArr); + } + } +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java index 4aae27af3..c95a6a8bb 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java @@ -22,7 +22,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; -import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -34,8 +33,7 @@ public class JsonTest extends BaseDL4JTest { InputPreProcessor[] pp = new InputPreProcessor[] { new KerasFlattenRnnPreprocessor(10, 5), new PermutePreprocessor(new int[]{0,1,2}), - new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}), - new TensorFlowCnnToFeedForwardPreProcessor() + new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}, true, null) }; for(InputPreProcessor p : pp ){ diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index 4d4bf067e..7f566a6cb 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceTo import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import 
org.nd4j.resources.Resources; @@ -250,7 +251,7 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { .enforceTrainingConfig(false).buildSequential().getMultiLayerConfiguration(); MultiLayerNetwork model = new MultiLayerNetwork(config); model.init(); - INDArray input = Nd4j.create(50, 500, 1500); + INDArray input = Nd4j.create(DataType.FLOAT, 50, 1500, 500); //NWC format - [Minibatch, seqLength, channels] INDArray out = model.output(input); assertTrue(Arrays.equals(out.shape(), new long[]{50, 64})); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index a23001444..f705f9f86 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -87,15 +87,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Rule public final TemporaryFolder testDir = new TemporaryFolder(); - public static final BiFunction nwc2ncwExpected = new BiFunction() { - @Override - public INDArray apply(String s, INDArray array) { - if(array.rank() == 3) - return array.permute(0, 2, 1); //NWC to NCW - return array; - } - }; - @Override public long getTimeoutMilliseconds() { return 180000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources @@ -169,28 +160,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importImdbLstmTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); } @Test public void importImdbLstmThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); } @Test public void importImdbLstmTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); } @Test public void importImdbLstmThKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, false, true, 
false, false, true, null, null); } /** @@ -262,7 +253,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); } /** @@ -316,7 +307,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Test public void importAcganDiscriminator() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_discriminator_1_epochs.h5"); - INDArray input = Nd4j.create(10, 1, 28, 28); + INDArray input = Nd4j.create(10, 28, 28, 1); //NHWC INDArray[] output = model.output(input); } @@ -403,7 +394,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { // Make predictions int miniBatch = 32; - INDArray input = Nd4j.ones(miniBatch, 4, 10); + INDArray input = Nd4j.ones(miniBatch, 10, 4); //NWC format - with nIn=4, seqLength = 10 INDArray[] out = graph.output(input); // Fit model @@ -450,7 +441,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Test public void importMobileNet() throws Exception { ComputationGraph graph = importFunctionalModelH5Test("modelimport/keras/examples/mobilenet/alternative.hdf5"); - INDArray input = Nd4j.ones(10, 3, 299, 299); + INDArray input = Nd4j.ones(10, 299, 299, 3); graph.output(input); } @@ -462,7 +453,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { int[] inputShape = new int[]{299, 299, 3}; ComputationGraph graph = importFunctionalModelH5Test( "modelimport/keras/examples/inception/inception_tf_keras_2.h5", inputShape, false); - INDArray input = Nd4j.ones(10, 3, 299, 299); + INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC graph.output(input); System.out.println(graph.summary()); } @@ -476,7 +467,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importInception() throws Exception { ComputationGraph graph = importFunctionalModelH5Test( "modelimport/keras/examples/inception/inception_v3_complete.h5"); - INDArray input = Nd4j.ones(10, 3, 299, 299); + INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC graph.output(input); System.out.println(graph.summary()); } @@ -533,14 +524,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { * - Separate (policy and value) residual architecture * - Separate (policy and value) convolutional architecture */ - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importSepConvPolicy() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_policy.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importSepResPolicy() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_policy.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); @@ -548,28 +539,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } - @Test + @Test 
@Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importSepConvValue() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_value.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importSepResValue() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_value.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importDualRes() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_res.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importDualConv() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_conv.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); @@ -634,16 +625,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { System.out.println("Starting test: " + name); String modelPath = "modelimport/keras/examples/causal_conv1d/" + name; String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); - Function f = new Function() { - @Override - public INDArray apply(INDArray i) { - //NWC to NCW - return i.permute(0, 2, 1); - } - }; MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, - true, true, false, f, nwc2ncwExpected); + true, true, false, null, null); Layer l = net.getLayer(0); Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig(); assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode()); @@ -707,25 +691,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { System.out.println("Starting test: " + name); String modelPath = "modelimport/keras/examples/conv1d/" + name; String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); - Function f = name.contains("_cf_") ? null : new Function() { - @Override - public INDArray apply(INDArray i) { - //NWC to NCW - return i.permute(0, 2, 1); - } - }; - - BiFunction f2 = name.contains("_cf_") ? 
null : new BiFunction() { - @Override - public INDArray apply(String s, INDArray array) { -// if("conv".equals(s)){ - return array.permute(0, 2, 1); -// } - } - }; importEndModelTest(modelPath, inputsOutputPath, true, true, - true, true, false, f, f2); + true, true, false, null, null); //f, f2); } } @@ -882,8 +850,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { INDArray[] inputs = new INDArray[inputNames.size()]; for (int i = 0; i < inputNames.size(); i++) { inputs[i] = archive.readDataSet(inputNames.get(i), GROUP_ATTR_INPUTS); - if (inputs[i].shape().length == 4 && tensorFlowImageDimOrdering) - inputs[i] = inputs[i].permute(0, 3, 1, 2); } return inputs; } @@ -893,8 +859,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { Map activations = new HashMap(); for (String layerName : archive.getDataSets(GROUP_ACTIVATIONS)) { INDArray activation = archive.readDataSet(layerName, GROUP_ACTIVATIONS); - if (activation.shape().length == 4 && tensorFlowImageDimOrdering) - activation = activation.permute(0, 3, 1, 2); activations.put(layerName, activation); } return activations; @@ -907,8 +871,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { INDArray[] outputs = new INDArray[outputNames.size()]; for (int i = 0; i < outputNames.size(); i++) { outputs[i] = archive.readDataSet(outputNames.get(i), GROUP_ATTR_OUTPUTS); - if (outputs[i].shape().length == 4 && tensorFlowImageDimOrdering) - outputs[i] = outputs[i].permute(0, 3, 1, 2); } return outputs; } @@ -920,8 +882,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { INDArray[] predictions = new INDArray[outputNames.size()]; for (int i = 0; i < outputNames.size(); i++) { predictions[i] = archive.readDataSet(outputNames.get(i), GROUP_PREDICTIONS); - if (predictions[i].shape().length == 4 && tensorFlowImageDimOrdering) - predictions[i] = predictions[i].permute(0, 3, 1, 2); } return predictions; } @@ -941,6 +901,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { // skip too small absolute inputs if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) { boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps); + if(!eq){ + System.out.println("Expected: " + Arrays.toString(expected.shape()) + ", actual: " + Arrays.toString(actual.shape())); + System.out.println("Expected:\n" + expected); + System.out.println("Actual: \n" + actual); + } assertTrue("Output differs: " + label, eq); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java index 75334bcd0..7204ebb82 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java @@ -176,10 +176,10 @@ public class KerasWeightSettingTests extends BaseDL4JTest { INDArray bias = model.getLayer(0).getParam("b"); assertEquals(6, bias.length()); - INDArray input = Nd4j.ones(1, 5, 3, 4); + INDArray input = Nd4j.ones(1, 3, 4, 5); //NHWC INDArray output = model.output(input); - assertArrayEquals(new long[] {1, 6, 1, 2}, output.shape()); + assertArrayEquals(new long[] {1, 1, 2, 6}, output.shape()); //NHWC logSuccess(modelPath); } @@ -224,7 +224,7 @@ public class KerasWeightSettingTests 
extends BaseDL4JTest { INDArray input = Nd4j.zeros(mb, inputLength); INDArray output = model.output(input); - assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape()); + assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC logSuccess(modelPath); } @@ -238,9 +238,9 @@ public class KerasWeightSettingTests extends BaseDL4JTest { KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class); MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); - INDArray input = Nd4j.zeros(10, 4, 6, 6); + INDArray input = Nd4j.zeros(10, 6, 6, 4); INDArray output = model.output(input); - assertArrayEquals(new long[]{10, 16, 3, 3}, output.shape()); + assertArrayEquals(new long[]{10, 3, 3, 16}, output.shape()); logSuccess(modelPath); } @@ -248,10 +248,11 @@ public class KerasWeightSettingTests extends BaseDL4JTest { KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class); ComputationGraph model = loadComputationalGraph(modelPath, false); - INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)}; +// INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)}; + INDArray input[] = new INDArray[]{Nd4j.zeros(10, 6, 6, 4), Nd4j.zeros(10, 3, 3, 16)}; INDArray[] output = model.output(input); log.info(Arrays.toString(output[0].shape())); - assertArrayEquals(new long[]{10, 32, 3, 3}, output[0].shape()); + assertArrayEquals(new long[]{10, 3, 3, 32}, output[0].shape()); logSuccess(modelPath); } @@ -278,7 +279,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest { INDArray inEmbedding = Nd4j.zeros(mb, inputLength); INDArray output = model.output(inEmbedding); - assertArrayEquals(new long[]{mb, nOut, inputLength}, output.shape()); + assertArrayEquals(new long[]{mb, inputLength, nOut}, output.shape()); //NWC format logSuccess(modelPath); } @@ -304,7 +305,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest { INDArray inEmbedding = Nd4j.zeros(mb, inputLength); INDArray output = model.output(inEmbedding); - assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape()); + assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC logSuccess(modelPath); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java index 62b8ac5e6..b40d76bff 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java @@ -9,7 +9,7 @@ package org.deeplearning4j.nn.conf; * * @author Alex Black */ -public enum CNN2DFormat { +public enum CNN2DFormat implements DataFormat { NCHW, NHWC; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/DataFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/DataFormat.java new file mode 100644 index 000000000..bde392a58 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/DataFormat.java @@ -0,0 +1,26 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.nn.conf; + +import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer; +import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer; +import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; +import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; + +@JsonSerialize(using = DataFormatSerializer.class) +@JsonDeserialize(using = DataFormatDeserializer.class) +public interface DataFormat { +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java index b1ac6a968..1a61acc4d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java @@ -663,7 +663,7 @@ public class MultiLayerConfiguration implements Serializable, Cloneable { BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer; val nIn = brl.getNIn(); if (nIn > 0) { - inputType = InputType.recurrent(nIn); + inputType = InputType.recurrent(nIn, brl.getRnnDataFormat()); } } else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer || firstLayer instanceof OutputLayer) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java index 186b405e7..c12857178 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java @@ -23,7 +23,7 @@ package org.deeplearning4j.nn.conf; * "width" corresponds to sequence length and "channels" corresponds to sequence item size. 
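Both concrete formats serialize to their bare enum names ("NCHW"/"NHWC" and "NCW"/"NWC"), so a DataFormat-typed field needs the serializer/deserializer pair referenced by the annotations on the interface above to round-trip through JSON unambiguously. A minimal round-trip sketch (not part of this patch; the class name is hypothetical, and it assumes the shaded Jackson ObjectMapper that DL4J bundles):

    import org.deeplearning4j.nn.conf.CNN2DFormat;
    import org.deeplearning4j.nn.conf.DataFormat;
    import org.nd4j.shade.jackson.databind.ObjectMapper;

    public class DataFormatRoundTripSketch {
        public static void main(String[] args) throws Exception {
            ObjectMapper mapper = new ObjectMapper();
            //Serializes to the bare enum name, e.g. "NHWC"
            String json = mapper.writeValueAsString(CNN2DFormat.NHWC);
            //The @JsonDeserialize annotation on the DataFormat interface routes this through
            //DataFormatDeserializer, which maps the name back to the correct enum type
            DataFormat restored = mapper.readValue(json, DataFormat.class);
            System.out.println(json + " -> " + restored);   //"NHWC" -> NHWC
        }
    }
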
*/ -public enum RNNFormat { +public enum RNNFormat implements DataFormat { NCW, NWC } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java index 77dd41c3a..726a68403 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java @@ -18,6 +18,8 @@ package org.deeplearning4j.nn.conf.graph; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.layers.Convolution3D; @@ -38,6 +40,8 @@ import org.nd4j.linalg.api.ndarray.INDArray; */ public class MergeVertex extends GraphVertex { + protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format + @Override public MergeVertex clone() { return new MergeVertex(); @@ -76,7 +80,7 @@ public class MergeVertex extends GraphVertex { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype) { - return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype); + return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype, mergeAxis); } @Override @@ -126,6 +130,7 @@ public class MergeVertex extends GraphVertex { //FF or RNN data inputs int size = 0; InputType.Type type = null; + RNNFormat format = null; for (int i = 0; i < vertexInputs.length; i++) { if (vertexInputs[i].getType() != first.getType()) { throw new InvalidInputTypeException( @@ -142,6 +147,8 @@ public class MergeVertex extends GraphVertex { break; case RNN: thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize(); + format = ((InputType.InputTypeRecurrent) vertexInputs[i]).getFormat(); + this.mergeAxis = format == RNNFormat.NCW ? 1 : 2; type = InputType.Type.RNN; break; default: @@ -160,7 +167,7 @@ public class MergeVertex extends GraphVertex { return InputType.feedForward(size); } else { val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength(); - return InputType.recurrent(size, tsLength); + return InputType.recurrent(size, tsLength, format); } } else { //size is unknown @@ -168,13 +175,14 @@ public class MergeVertex extends GraphVertex { return InputType.feedForward(-1); } else { val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength(); - return InputType.recurrent(-1, tsLength); + return InputType.recurrent(-1, tsLength, format); } } } else { //CNN inputs... also check that the channels, width and heights match: InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; + CNN2DFormat format = firstConv.getFormat(); val fd = firstConv.getChannels(); val fw = firstConv.getWidth(); @@ -206,7 +214,8 @@ public class MergeVertex extends GraphVertex { depthSum += od; } - return InputType.convolutional(fh, fw, depthSum); + this.mergeAxis = format == CNN2DFormat.NCHW ? 
1 : 3; + return InputType.convolutional(fh, fw, depthSum, format); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java index 2c7a4e5f8..ce5e8c78f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; +import org.deeplearning4j.nn.conf.DataFormat; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.layers.Convolution3D; @@ -91,7 +92,11 @@ public abstract class InputType implements Serializable { * @return InputTypeFeedForward */ public static InputType feedForward(long size) { - return new InputTypeFeedForward(size); + return new InputTypeFeedForward(size, null); + } + + public static InputType feedForward(long size, DataFormat timeDistributedFormat) { + return new InputTypeFeedForward(size,timeDistributedFormat); } /** @@ -132,7 +137,6 @@ public abstract class InputType implements Serializable { * @return InputTypeConvolutional */ public static InputType convolutional(long height, long width, long depth) { -// return new InputTypeConvolutional(height, width, depth); return convolutional(height, width, depth, CNN2DFormat.NCHW); } @@ -191,9 +195,11 @@ public abstract class InputType implements Serializable { @EqualsAndHashCode(callSuper = false) public static class InputTypeFeedForward extends InputType { private long size; + private DataFormat timeDistributedFormat; - public InputTypeFeedForward(@JsonProperty("size") long size) { + public InputTypeFeedForward(@JsonProperty("size") long size, @JsonProperty("timeDistributedFormat") DataFormat timeDistributedFormat) { this.size = size; + this.timeDistributedFormat = timeDistributedFormat; } @Override @@ -203,7 +209,7 @@ public abstract class InputType implements Serializable { @Override public String toString() { - return "InputTypeFeedForward(" + size + ")"; + return "InputTypeFeedForward(" + size + (timeDistributedFormat != null ? "," + timeDistributedFormat : "") + ")"; } @Override @@ -302,7 +308,8 @@ public abstract class InputType implements Serializable { this.height = height; this.width = width; this.channels = channels; - this.format = format; + if(format != null) + this.format = format; } public InputTypeConvolutional(long height, long width, long channels) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java index fda5aba83..0b98dfad9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java @@ -64,11 +64,11 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { + "\"): expect RNN input type with size > 0. 
Got: " + inputType); } + InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; if (nIn <= 0 || override) { - InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; this.nIn = r.getSize(); - this.rnnDataFormat = r.getFormat(); } + this.rnnDataFormat = r.getFormat(); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index 32b57f34b..f4d247670 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -44,6 +44,7 @@ import java.util.Map; @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) public class Convolution1DLayer extends ConvolutionLayer { + private RNNFormat rnnDataFormat = RNNFormat.NCW; /* //TODO: We will eventually want to NOT subclass off of ConvolutionLayer. //Currently, we just subclass off the ConvolutionLayer and hard code the "width" dimension to 1 @@ -56,6 +57,7 @@ public class Convolution1DLayer extends ConvolutionLayer { private Convolution1DLayer(Builder builder) { super(builder); initializeConstraints(builder); + this.rnnDataFormat = builder.rnnDataFormat; } @Override @@ -92,7 +94,8 @@ public class Convolution1DLayer extends ConvolutionLayer { outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0], convolutionMode, dilation[0]); } - return InputType.recurrent(nOut, outLength); + + return InputType.recurrent(nOut, outLength, rnnDataFormat); } @Override @@ -102,10 +105,11 @@ public class Convolution1DLayer extends ConvolutionLayer { + "\"): expect RNN input type with size > 0. 
Got: " + inputType); } + InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; if (nIn <= 0 || override) { - InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; this.nIn = r.getSize(); } + this.rnnDataFormat = r.getFormat(); } @Override @@ -115,11 +119,13 @@ public class Convolution1DLayer extends ConvolutionLayer { + "\"): input is null"); } - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat,getLayerName()); } public static class Builder extends ConvolutionLayer.BaseConvBuilder { + private RNNFormat rnnDataFormat = RNNFormat.NCW; + public Builder() { this(0, 1, 0); this.setKernelSize((int[]) null); @@ -130,6 +136,11 @@ public class Convolution1DLayer extends ConvolutionLayer { return true; } + + public Builder rnnDataFormat(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + return this; + } /** * @param kernelSize Kernel size * @param stride Stride diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java index 9b7725801..1b76b6c7b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -58,12 +59,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { private int inputLength = 1; // By default only use one index to embed private boolean hasBias = false; private boolean inferInputLength = false; // use input length as provided by input data + private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models private EmbeddingSequenceLayer(Builder builder) { super(builder); this.hasBias = builder.hasBias; this.inputLength = builder.inputLength; this.inferInputLength = builder.inferInputLength; + this.outputFormat = builder.outputFormat; initializeConstraints(builder); } @@ -87,7 +90,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { throw new IllegalStateException("Invalid input for Embedding layer (layer index = " + layerIndex + ", layer name = \"" + getLayerName() + "\"): expect FF/RNN input type. Got: " + inputType); } - return InputType.recurrent(nOut, inputLength); + return InputType.recurrent(nOut, inputLength, outputFormat); } @Override @@ -167,6 +170,13 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { */ private boolean inferInputLength = true; + private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models + + public Builder outputDataFormat(RNNFormat format){ + this.outputFormat = format; + return this; + } + /** * If true: include bias parameters in the layer. False (default): no bias. 
* diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java index 9b49f6415..20d3c926a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import org.deeplearning4j.nn.conf.DataFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor; @@ -35,6 +36,7 @@ public abstract class FeedForwardLayer extends BaseLayer { protected long nIn; protected long nOut; + protected DataFormat timeDistributedFormat; public FeedForwardLayer(Builder builder) { super(builder); @@ -51,7 +53,7 @@ public abstract class FeedForwardLayer extends BaseLayer { + getLayerName() + "\"): expected FeedForward input type. Got: " + inputType); } - return InputType.feedForward(nOut); + return InputType.feedForward(nOut, timeDistributedFormat); } @Override @@ -71,6 +73,11 @@ public abstract class FeedForwardLayer extends BaseLayer { this.nIn = f.getFlattenedSize(); } } + + if(inputType instanceof InputType.InputTypeFeedForward){ + InputType.InputTypeFeedForward f = (InputType.InputTypeFeedForward) inputType; + this.timeDistributedFormat = f.getTimeDistributedFormat(); + } } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java index 06a71f01d..a60c3a6bc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java @@ -536,11 +536,17 @@ public class InputTypeUtil { } switch (inputType.getType()) { - case FF: case CNNFlat: //FF -> RNN or CNNFlat -> RNN //In either case, input data format is a row vector per example return new FeedForwardToRnnPreProcessor(rnnDataFormat); + case FF: + //If time distributed format is defined, use that. 
Otherwise use the layer-defined rnnDataFormat, which may be default + InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward)inputType; + if(ff.getTimeDistributedFormat() != null && ff.getTimeDistributedFormat() instanceof RNNFormat){ + return new FeedForwardToRnnPreProcessor((RNNFormat) ff.getTimeDistributedFormat()); + } + return new FeedForwardToRnnPreProcessor(rnnDataFormat); case RNN: //RNN -> RNN: No preprocessor necessary return null; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java index ec0ecf59c..cfd337514 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java @@ -98,9 +98,9 @@ public class RnnOutputLayer extends BaseOutputLayer { + "\"): Expected RNN input, got " + inputType); } + InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; + this.rnnDataFormat = r.getFormat(); if (nIn <= 0 || override) { - InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; - this.rnnDataFormat = r.getFormat(); this.nIn = r.getSize(); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index 4b2b959c2..5d2a55994 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -91,7 +91,7 @@ public class Subsampling1DLayer extends SubsamplingLayer { outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0], convolutionMode, dilation[0]); } - return InputType.recurrent(r.getSize(), outLength); + return InputType.recurrent(r.getSize(), outLength, r.getFormat()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java index 424130f17..e7f252c4b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers.misc; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; @@ -46,10 +47,12 @@ import java.util.Map; public class RepeatVector extends FeedForwardLayer { private int n = 1; + private RNNFormat dataFormat = RNNFormat.NCW; protected RepeatVector(Builder builder) { super(builder); this.n = builder.n; + this.dataFormat = builder.dataFormat; } @Override @@ -83,7 +86,7 @@ public class RepeatVector extends FeedForwardLayer { + "\"): Expected FF input, got " + inputType); } InputType.InputTypeFeedForward ffInput = (InputType.InputTypeFeedForward) inputType; - return 
InputType.recurrent(ffInput.getSize(), n); + return InputType.recurrent(ffInput.getSize(), n, this.dataFormat); } @Override @@ -101,13 +104,14 @@ public class RepeatVector extends FeedForwardLayer { } + @NoArgsConstructor @Getter @Setter public static class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> { private int n = 1; // no repetition by default - + private RNNFormat dataFormat = RNNFormat.NCW; /** * Set repetition factor for RepeatVector layer */ @@ -115,6 +119,15 @@ public class RepeatVector extends FeedForwardLayer { return n; } + public RNNFormat getDataFormat(){ + return dataFormat; + } + + public Builder dataFormat(RNNFormat dataFormat){ + this.dataFormat = dataFormat; + return this; + } + /** * Set repetition factor for RepeatVector layer * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java index 681d1f3f9..a9f82ba32 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java @@ -39,11 +39,13 @@ import java.util.Arrays; * For example, CNN -> Denselayer<br>
* This does two things:
* (a) Reshapes 4d activations out of CNN layer, with shape - * [numExamples, numChannels, inputHeight, inputWidth]) into 2d activations (with shape - * [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer + * [numExamples, numChannels, inputHeight, inputWidth] (for {@link CNN2DFormat#NCHW} format activations) or shape + * [numExamples, inputHeight, inputWidth, numChannels] (for {@link CNN2DFormat#NHWC} format activations) into 2d activations + * (with shape [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer. * (b) Reshapes epsilons (weights*deltas) out of FeedForward layer (which is 2D or 3D with shape * [numExamples, inputHeight*inputWidth*numChannels]) into 4d epsilons (with shape - * [numExamples, numChannels, inputHeight, inputWidth]) suitable to feed into CNN layers.<br>
+ * [numExamples, numChannels, inputHeight, inputWidth] or [numExamples, inputHeight, inputWidth, numChannels]) suitable to + * feed into CNN layers.
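Concretely, flattening is now a plain c-order reshape for both layouts, with no permute beforehand; for NHWC activations this is exactly what Keras/TF Flatten computes. A sketch of that reshape in raw ND4J (shapes illustrative only):

    //[mb=2, h=3, w=3, c=4] NHWC activations, flattened the way the preprocessor now does it
    INDArray nhwc = Nd4j.arange(2 * 3 * 3 * 4).reshape('c', 2, 3, 3, 4);
    INDArray flat = nhwc.reshape('c', 2, 3 * 3 * 4);   //[mb, h*w*c] - no NHWC -> NCHW permute

The backprop path mirrors this: for NHWC, epsilons are reshaped straight back to [mb, h, w, c] rather than being reshaped to NCHW and permuted.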
* Note: numChannels is equivalent to channels or featureMaps referenced in different literature * @author Adam Gibson * @see FeedForwardToCnnPreProcessor for opposite case (i.e., DenseLayer -> CNNetc) @@ -68,7 +70,8 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { this.inputHeight = inputHeight; this.inputWidth = inputWidth; this.numChannels = numChannels; - this.format = format; + if(format != null) + this.format = format; } public CnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) { @@ -96,10 +99,17 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { wDim = 2; } + if(inputHeight == 0 && inputWidth == 0 && numChannels == 0){ + this.inputHeight = input.size(hDim); + this.inputWidth = input.size(wDim); + this.numChannels = input.size(chDim); + } + if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || input.size(wDim) != inputWidth){ - throw new IllegalStateException("Invalid input, does not match configuration: expected [minibatch, numChannels=" - + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] but got input array of" + - "shape " + Arrays.toString(input.shape())); + throw new IllegalStateException("Invalid input, does not match configuration: expected " + + (format == CNN2DFormat.NCHW ? "[minibatch, numChannels=" + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] " : + "[minibatch, inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + ", numChannels=" + numChannels + "]") + + " but got input array of shape " + Arrays.toString(input.shape())); } //Check input: nchw format @@ -110,15 +120,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { + Arrays.toString(input.shape())); } - if(format == CNN2DFormat.NHWC) { - input = input.permute(0, 3, 1, 2); //NHWC to NCHW - } - //Assume input is standard rank 4 activations out of CNN layer //First: we require input to be in c order. 
But c order (as declared in array order) isn't enough; also need strides to be correct if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); + //Note that to match Tensorflow/Keras, we do a simple "c order reshape" for both NCHW and NHWC + val inShape = input.shape(); //[miniBatch,depthOut,outH,outW] val outShape = new long[]{inShape[0], inShape[1] * inShape[2] * inShape[3]}; @@ -139,11 +147,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { + inputHeight + " x columns " + inputWidth + " x channels " + numChannels + " but was instead " + Arrays.toString(epsilons.shape())); - INDArray ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth); - - if(format == CNN2DFormat.NHWC){ - ret = ret.permute(0,2,3,1); //NCHW to NHWC + INDArray ret; + if(format == CNN2DFormat.NCHW){ + ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth); + } else { + ret = epsilons.reshape('c', epsilons.size(0), inputHeight, inputWidth, numChannels); } + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java index 7da79b935..d0b01698e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java @@ -52,7 +52,8 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor { private RNNFormat rnnDataFormat = RNNFormat.NCW; public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){ - this.rnnDataFormat = rnnDataFormat; + if(rnnDataFormat != null) + this.rnnDataFormat = rnnDataFormat; } @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java index 125aaf78b..f3e9323a5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java @@ -57,7 +57,8 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor { private RNNFormat rnnDataFormat = RNNFormat.NCW; public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){ - this.rnnDataFormat = rnnDataFormat; + if(rnnDataFormat != null) + this.rnnDataFormat = rnnDataFormat; } @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { @@ -116,7 +117,7 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor { } InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType; - return InputType.feedForward(rnn.getSize()); + return InputType.feedForward(rnn.getSize(), rnn.getFormat()); } @Override diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java new file mode 100644 index 000000000..ec50a5edb --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java @@ -0,0 +1,52 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.nn.conf.serde.format; + +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.DataFormat; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.nd4j.shade.jackson.core.JsonParser; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.DeserializationContext; +import org.nd4j.shade.jackson.databind.JsonDeserializer; +import org.nd4j.shade.jackson.databind.JsonNode; + +import java.io.IOException; + +/** + * Simple JSON deserializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat} + * + * @author Alex Black + */ +public class DataFormatDeserializer extends JsonDeserializer<DataFormat> { + @Override + public DataFormat deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + JsonNode node = jp.getCodec().readTree(jp); + String text = node.textValue(); + switch (text){ + case "NCHW": + return CNN2DFormat.NCHW; + case "NHWC": + return CNN2DFormat.NHWC; + case "NCW": + return RNNFormat.NCW; + case "NWC": + return RNNFormat.NWC; + default: + return null; + } + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java new file mode 100644 index 000000000..9abe90d38 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java @@ -0,0 +1,37 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java
new file mode 100644
index 000000000..9abe90d38
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java
@@ -0,0 +1,37 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.nn.conf.serde.format;
+
+import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.DataFormat;
+import org.deeplearning4j.nn.conf.RNNFormat;
+import org.nd4j.shade.jackson.core.JsonGenerator;
+import org.nd4j.shade.jackson.databind.JsonSerializer;
+import org.nd4j.shade.jackson.databind.SerializerProvider;
+
+import java.io.IOException;
+
+/**
+ * Simple JSON serializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat}
+ *
+ * @author Alex Black
+ */
+public class DataFormatSerializer extends JsonSerializer<DataFormat> {
+    @Override
+    public void serialize(DataFormat dataFormat, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
+        jsonGenerator.writeString(dataFormat.toString());
+    }
+}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java
index cb58a9813..4f864d0cc 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java
@@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.primitives.Pair;
 import org.deeplearning4j.nn.workspace.ArrayType;
@@ -48,14 +49,16 @@ public class MergeVertex extends BaseGraphVertex {
 
     private long[][] forwardPassShapes;
     private int fwdPassRank;
+    private int mergeAxis;
 
-    public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType) {
-        this(graph, name, vertexIndex, null, null, dataType);
+    public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType, int mergeAxis) {
+        this(graph, name, vertexIndex, null, null, dataType, mergeAxis);
     }
 
     public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
-                    VertexIndices[] outputVertices, DataType dataType) {
+                    VertexIndices[] outputVertices, DataType dataType, int mergeAxis) {
         super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
+        this.mergeAxis = mergeAxis;
     }
 
     @Override
@@ -92,7 +95,6 @@ public class MergeVertex extends BaseGraphVertex {
         forwardPassShapes = new long[in.length][0];
         val nExamples = in[0].size(0);
-        int nOut = 0;
         fwdPassRank = in[0].rank();
         for (int i = 0; i < in.length; i++) {
             val currShape = in[i].shape();
@@ -109,12 +111,11 @@
                                 + Arrays.toString(in[0].shape()) + ", activations[" + i + "] shape: " + Arrays.toString(in[i].shape()));
             }
-
-            nOut += currShape[1]; //Same dimension for all of CNNs, FF, RNNs
         }
 
         try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
-            return Nd4j.concat(1, in);
+            INDArray out = Nd4j.concat(mergeAxis, in);
+            return out;
         }
     }
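The mergeAxis field introduced above is the heart of the MergeVertex change: concatenation now happens on the channel axis of the configured data format rather than always on dimension 1, which is what makes NWC and NHWC merging work (channels sit last there). An illustrative sketch for two NWC inputs (shapes arbitrary):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class MergeAxisSketch {
        public static void main(String[] args) {
            // Two NWC activations: [miniBatch, seqLength, channels]
            INDArray a = Nd4j.rand(DataType.FLOAT, 2, 10, 8);
            INDArray b = Nd4j.rand(DataType.FLOAT, 2, 10, 4);

            // NCW would merge on axis 1; NWC puts channels last, so mergeAxis = 2
            INDArray merged = Nd4j.concat(2, a, b);
            System.out.println(java.util.Arrays.toString(merged.shape())); // [2, 10, 12]
        }
    }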
@@ -145,20 +146,16 @@
                 break;
             case 3:
                 for (int i = 0; i < forwardPassShapes.length; i++) {
-                    out[i].assign(epsilon.get(NDArrayIndex.all(), //All rows
-                                    NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //subset of columns
-                                    NDArrayIndex.all())); //All time steps
+                    out[i].assign(epsilon.get(indices(3, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //All time steps
 
-                    cumulative += forwardPassShapes[i][1];
+                    cumulative += forwardPassShapes[i][mergeAxis];
                 }
                 break;
             case 4:
                 for (int i = 0; i < forwardPassShapes.length; i++) {
-                    out[i].assign(epsilon.get(NDArrayIndex.all(),
-                                    NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //Subset of depth
-                                    NDArrayIndex.all(), //Width
-                                    NDArrayIndex.all())); //height
-                    cumulative += forwardPassShapes[i][1];
+                    out[i].assign(epsilon.get(indices(4, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //height
+
+                    cumulative += forwardPassShapes[i][mergeAxis];
                 }
                 break;
             default:
@@ -168,6 +165,19 @@
         return new Pair<>(null, out);
     }
 
+    private INDArrayIndex[] indices(int num, int axis, long from, long to){
+        INDArrayIndex[] out = new INDArrayIndex[num];
+        for( int i=0; i<num; i++ ){
+            if(i == axis){
+                out[i] = NDArrayIndex.interval(from, to);
+            } else {
+                out[i] = NDArrayIndex.all();
+            }
+        }
+        return out;
+    }
+
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution1DLayer.java
+        return new Pair<>(retGradient, epsOut);
     }
@@ -140,7 +150,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
         // remove singleton fourth dimension from input and current epsilon
         epsNext = epsNext.reshape(epsNext.size(0), epsNext.size(1), epsNext.size(2));
         input = origInput;
-
+        if (getRnnDataFormat() == RNNFormat.NWC){
+            epsNext = epsNext.permute(0, 2, 1);
+            this.input = input.permute(0, 2, 1);
+        }
         return new Pair<>(gradientEpsNext.getFirst(), epsNext);
     }
 
@@ -185,7 +198,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
                         .s(c.getStride()[0])
                         .d(c.getDilation()[0])
                         .p(c.getPadding()[0])
-                        .dataFormat(Conv1DConfig.NCW)
+                        .dataFormat(Conv1DConfig.NCW) //NWC inputs are permuted to NCW before this point, so the op always runs NCW
                         .paddingMode(PaddingMode.CAUSAL)
                         .build();
         INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
@@ -209,6 +223,9 @@ public class Convolution1DLayer extends ConvolutionLayer {
 
     @Override
     public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
+        if (getRnnDataFormat() == RNNFormat.NWC){
+            this.input = input.permute(0, 2, 1);
+        }
         INDArray act4d = super.activate(training, workspaceMgr);
         INDArray act3d = act4d.reshape(act4d.size(0), act4d.size(1), act4d.size(2));
 
@@ -219,6 +236,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
                             act3d.shape(), maskOut.shape());
             Broadcast.mul(act3d, maskOut, act3d, 0, 2);
         }
+        if (getRnnDataFormat() == RNNFormat.NWC){
+            this.input = input.permute(0, 2, 1);
+            act3d = act3d.permute(0, 2, 1);
+        }
         return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, act3d); //Should be zero copy most of the time
     }
 
@@ -231,4 +252,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
                         layerConf().getConvolutionMode());
         return new Pair<>(reduced, currentMaskState);
     }
+
+    private RNNFormat getRnnDataFormat(){
+        return ((org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf()).getRnnDataFormat();
+    }
 }
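The indices(...) helper generalizes the previously hard-coded NDArrayIndex slicing in doBackward: an interval on the merge axis, all() everywhere else. A standalone sketch of the same idea (the helper is restated here under the assumption it matches the patch; shapes and class name are arbitrary):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.indexing.INDArrayIndex;
    import org.nd4j.linalg.indexing.NDArrayIndex;

    public class MergeBackpropSliceSketch {
        // Interval on the merge axis, all() on every other axis
        static INDArrayIndex[] indices(int num, int axis, long from, long to) {
            INDArrayIndex[] out = new INDArrayIndex[num];
            for (int i = 0; i < num; i++) {
                out[i] = (i == axis) ? NDArrayIndex.interval(from, to) : NDArrayIndex.all();
            }
            return out;
        }

        public static void main(String[] args) {
            INDArray epsilon = Nd4j.rand(DataType.FLOAT, 2, 5, 5, 12); // NHWC epsilon, mergeAxis = 3
            INDArray firstInputEps = epsilon.get(indices(4, 3, 0, 8)); // channels 0..7 for the first merged input
            System.out.println(java.util.Arrays.toString(firstInputEps.shape())); // [2, 5, 5, 8]
        }
    }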
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java
index ddaefeaa4..2b8a5148b 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java
@@ -160,7 +160,8 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.ConvolutionLayer> {
         Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr);
         INDArray z = p.getFirst();
-        if(layerConf().getCnn2dDataFormat() != CNN2DFormat.NCHW){
+        CNN2DFormat f = layerConf().getCnn2dDataFormat();
+        if(f != CNN2DFormat.NCHW){
             z = z.permute(0,3,1,2); //NHWC to NCHW
         }
         delta = afn.backprop(z, epsilon).getFirst(); //TODO handle activation function params
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java
index 2987d4be2..96c67caf7 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java
@@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.gradient.DefaultGradient;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.BaseLayer;
@@ -64,8 +65,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer> {
+            ret = ret.permute(0, 2, 1); //[minibatch, seqLen, nOut] -> [minibatch, nOut, seqLen] i.e., NWC -> NCW
+        }
         return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
     }
@@ -177,8 +190,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer> {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java
     public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
         long[] newEpsShape = origOutputShape;
-        boolean nwc = (underlying instanceof BaseRecurrentLayer &&
-                ((BaseRecurrentLayer) underlying).getDataFormat() == RNNFormat.NWC)||
-                (underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof
-                        BaseRecurrentLayer && ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat()
-                        == RNNFormat.NWC);
+
+        boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC;
 
         INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f');
         if(lastTimeStepIdxs == null){ //no mask case
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java
index 85d5a9839..04f8ccce3 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java
@@ -58,7 +58,8 @@ public class RnnOutputLayer extends BaseOutputLayer
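A recurring idiom throughout these layer fixes is converting between channels-first and channels-last by permuting axes, never by reshaping (a reshape would scramble the data). For reference, the permutations used above, gathered in an illustrative standalone class:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class FormatPermuteSketch {
        public static void main(String[] args) {
            // 2D CNN activations
            INDArray nhwc = Nd4j.rand(DataType.FLOAT, 2, 5, 7, 3); // [mb, h, w, c]
            INDArray nchw = nhwc.permute(0, 3, 1, 2);              // NHWC -> NCHW, as in ConvolutionLayer.backpropGradient
            INDArray back = nchw.permute(0, 2, 3, 1);              // NCHW -> NHWC, the inverse

            // RNN / Conv1D / embedding-sequence activations
            INDArray nwc = Nd4j.rand(DataType.FLOAT, 2, 10, 6);    // [mb, seqLen, c]
            INDArray ncw = nwc.permute(0, 2, 1);                   // NWC <-> NCW; the same permute is its own inverse

            System.out.println(back.equalShapes(nhwc));                      // true
            System.out.println(java.util.Arrays.toString(ncw.shape()));      // [2, 6, 10]
        }
    }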