tf.keras import test and fixes (#347)

* merge conf

* merge conf

* tfkeras tests

* parameterized tests

* rename

* cuda versions

* jccp versions

* 'updates'

* updates

* rnn+mlp passing

* repeat

* updates

* tests

* Update pom.xml

* Update pom.xml

* rem print

* cnn1d model conversion fixed

* cnn1d activate fixed

* cnn1d output shape fix

* cnn1d bprop fix

* cnn1d stack fix

* KerasModelEndToEndTest - Remove permutes for NWC and NHWC format tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes and update test - input shapes (NCHW -> NHWC input)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Ignore for known bad tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Multiple fixes - MergeVertex, CNN1D layers, etc

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix issue with RNN/FF preprocessors, time distributed etc with NWC format

Signed-off-by: Alex Black <blacka101@gmail.com>

* LSTM NWC dropout fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Add sequence embedding layer NWC support (configurable output format)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix expected shape in a couple of tests - NWC expected

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix EmbeddingSequenceLayer backprop for NWC output case + add gradient checks

Signed-off-by: Alex Black <blacka101@gmail.com>

* CnnToFeedForwardPreprocessor: align with Keras/TF; fix Keras reshape/flatten

Signed-off-by: Alex Black <blacka101@gmail.com>

* Update ConvDataFormatTests to match new reshape behaviour

Signed-off-by: Alex Black <blacka101@gmail.com>

* Switch hard-coded path to ResourceUtils.listClassPathFiles for TestTFKerasModelImport

Signed-off-by: Alex Black <blacka101@gmail.com>

* TestUtils fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix JSON serde issue with data formats

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix for input dtype inference; fix 2 tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8891 Ignore for TestVertxUIMultiSession until fixed

Signed-off-by: Alex Black <blacka101@gmail.com>

* Restore but deprecate TensorFlowCnnToFeedForwardPreProcessor for older zoo models

Signed-off-by: Alex Black <blacka101@gmail.com>

* Ignore for deprecated preprocessor in DTypeTests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Remove debug printlns

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
Branch: master
Fariz Rahman 2020-04-28 14:31:09 +04:00 committed by GitHub
parent b9d5f1645b
commit 4cb87a94e8
67 changed files with 1336 additions and 438 deletions
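
Taken together, the changes below make DL4J's Keras import data-format aware: RNN layers and preprocessors gain NCW/NWC support and CNN layers gain NCHW/NHWC support, so models imported from tf.keras (channels-last by default) no longer need permute workarounds. A minimal sketch of the configuration style the diffs rely on; the layer sizes here are illustrative, not taken from the PR:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

// NWC ("channels last", shape [minibatch, seqLength, size]) is what tf.keras produces
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(new LSTM.Builder().nIn(5).nOut(8).dataFormat(RNNFormat.NWC).build())
        .layer(new RnnOutputLayer.Builder().nIn(8).nOut(3)
                .activation(Activation.SOFTMAX)
                .dataFormat(RNNFormat.NWC)
                .lossFunction(LossFunction.MCXENT).build())
        .setInputType(InputType.recurrent(5, 9, RNNFormat.NWC))
        .build();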

View File

@ -20,6 +20,7 @@ package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.cpython.global.python;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -343,6 +344,19 @@ public class PythonExecutioner {
if (path == null) {
log.info("Setting python default path");
File[] packages = numpy.cachePackages();
// TODO: fix in javacpp - also append the Windows site-packages directory to the path
File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
File[] packages2 = new File[packages.length + 1];
for (int i = 0; i < packages.length; i++){
packages2[i] = packages[i];
}
packages2[packages.length] = sitePackagesWindows;
packages = packages2;
Py_SetPath(packages);
} else {
log.info("Setting python path " + path);

View File

@ -0,0 +1,132 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
@Slf4j
public class PythonProcess {
private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
Process process = pb.start();
String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
process.waitFor();
return out;
}
public static void run(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
pb.inheritIO().start().waitFor();
}
public static void pipInstall(String packageName) throws PythonException{
try{
run("-m", "pip", "install", packageName);
}catch(Exception e){
throw new PythonException("Error installing package " + packageName, e);
}
}
public static void pipInstall(String packageName, String version) throws PythonException{
pipInstall(packageName + "==" + version);
}
public static void pipUninstall(String packageName) throws PythonException{
try{
run("-m", "pip", "uninstall", packageName);
}catch(Exception e){
throw new PythonException("Error uninstalling package " + packageName, e);
}
}
public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{
if (!gitRepoUrl.contains("://")){
gitRepoUrl = "git://" + gitRepoUrl;
}
try{
run("-m", "pip", "install", "git+", gitRepoUrl);
}catch(Exception e){
throw new PythonException("Error installing package from " + gitRepoUrl, e);
}
}
public static String getPackageVersion(String packageName) throws PythonException{
String out;
try{
out = runAndReturn("-m", "pip", "show", packageName);
} catch (Exception e){
throw new PythonException("Error finding version for package " + packageName, e);
}
if (!out.contains("Version: ")){
throw new PythonException("Can't find package " + packageName);
}
String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
return pkgVersion;
}
public static boolean isPackageInstalled(String packageName)throws PythonException{
try{
String out = runAndReturn("-m", "pip", "show", packageName);
return !out.isEmpty();
}catch (Exception e){
throw new PythonException("Error checking if package is installed: " +packageName, e);
}
}
public static void pipInstallFromRequirementsTxt(String path) throws PythonException{
try{
run("-m", "pip", "install","-r", path);
}catch (Exception e){
throw new PythonException("Error installing packages from " + path, e);
}
}
public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException{
try{
run(path, inplace?"develop":"install");
}catch (Exception e){
throw new PythonException("Error installing package from " + path, e);
}
}
}
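
A usage sketch for the new PythonProcess helper; the tensorflow version pin is illustrative, and a working JavaCPP CPython install is assumed:

import org.datavec.python.PythonProcess;

public class PythonProcessExample {
    public static void main(String[] args) throws Exception {
        // Check, install and query a package via the bundled CPython's pip
        if (!PythonProcess.isPackageInstalled("tensorflow")) {
            PythonProcess.pipInstall("tensorflow", "2.1.0"); // version pin is an assumption
        }
        System.out.println("tensorflow " + PythonProcess.getPackageVersion("tensorflow"));
    }
}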

View File

@ -0,0 +1,144 @@
package org.datavec.python.keras;
import org.datavec.python.Python;
import org.datavec.python.PythonException;
import org.datavec.python.PythonObject;
import org.datavec.python.PythonProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
public class Model {
private PythonObject pyModel;
private static PythonObject installAndImportTF() throws PythonException{
if (!PythonProcess.isPackageInstalled("tensorflow")){
PythonProcess.pipInstall("tensorflow");
}
return Python.importModule("tensorflow");
}
private static PythonObject getKerasModule() throws PythonException{
PythonObject tf = installAndImportTF();
PythonObject keras = tf.attr("keras");
tf.del();
return keras;
}
private static PythonObject loadModel(String s) throws PythonException{
PythonObject models = getKerasModule().attr("models");
PythonObject loadModelF = models.attr("load_model");
PythonObject model = loadModelF.call(s);
models.del();
loadModelF.del();
return model;
}
public Model(String path) throws PythonException{
pyModel = loadModel(path);
}
public INDArray[] predict(INDArray... inputs) throws PythonException{
PythonObject predictF = pyModel.attr("predict");
PythonObject inputList = new PythonObject(inputs);
PythonObject pyOut = predictF.call(inputList);
INDArray[] out;
if (Python.isinstance(pyOut, Python.listType())){
out = new INDArray[Python.len(pyOut).toInt()];
for(int i = 0; i < out.length; i++){
out[i] = pyOut.get(i).toNumpy().getNd4jArray();
}
}
else{
out = new INDArray[]{
pyOut.toNumpy().getNd4jArray()};
}
predictF.del();
inputList.del();
pyOut.del();
return out;
}
public int numInputs(){
PythonObject inputs = pyModel.attr("inputs");
PythonObject pyNumInputs = Python.len(inputs);
int ret = pyNumInputs.toInt();
inputs.del();
pyNumInputs.del();
return ret;
}
public int numOutputs(){
PythonObject outputs = pyModel.attr("outputs");
PythonObject pyNumOutputs = Python.len(outputs);
int ret = pyNumOutputs.toInt();
outputs.del();
pyNumOutputs.del();
return ret;
}
public long[][] inputShapes(){
long[][] ret = new long[numInputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = inputShapeAt(i);
}
return ret;
}
public long[][] outputShapes(){
long[][] ret = new long[numOutputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = outputShapeAt(i);
}
return ret;
}
public long[] inputShapeAt(int input){
PythonObject inputs = pyModel.attr("inputs");
PythonObject tensor = inputs.get(input);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
inputs.del();
return shape;
}
public long[] outputShapeAt(int output){
PythonObject inputs = pyModel.attr("outputs");
PythonObject tensor = inputs.get(output);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
outputs.del();
return shape;
}
}
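
A sketch of how the wrapper might be used; the model path is a placeholder, and a single rank-2 input is assumed:

import org.datavec.python.keras.Model;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class KerasModelExample {
    public static void main(String[] args) throws Exception {
        Model model = new Model("/path/to/model.h5");    // placeholder path
        long[] inShape = model.inputShapeAt(0);          // dynamic dims (e.g. batch) come back as -1
        INDArray input = Nd4j.rand(1, (int) inShape[1]); // assumes a [batch, features] input
        INDArray[] out = model.predict(input);
        System.out.println(out[0].shapeInfoToString());
    }
}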

View File

@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
@ -154,10 +155,21 @@ public class TestUtils {
}
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) {
INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng);
}
public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){
boolean ncw = format == RNNFormat.NCW;
long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize};
char order = ncw ? 'f' : 'c';
INDArray out = Nd4j.create(DataType.FLOAT, shape, order);
for( int i=0; i<minibatch; i++ ){
for( int j=0; j<tsLength; j++ ){
if(ncw){
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
} else {
out.putScalar(i, j, rng.nextInt(outSize), 1.0);
}
}
}
return out;
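
For reference, a short sketch of the label shapes the new overload produces under each format (sizes illustrative):

// NCW ("channels first"): labels shaped [minibatch, outSize, tsLength]
INDArray ncw = TestUtils.randomOneHotTimeSeries(RNNFormat.NCW, 2, 3, 4, new Random(12345)); // [2, 3, 4]
// NWC ("channels last"): labels shaped [minibatch, tsLength, outSize]
INDArray nwc = TestUtils.randomOneHotTimeSeries(RNNFormat.NWC, 2, 3, 4, new Random(12345)); // [2, 4, 3]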

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.*;
@ -560,6 +561,7 @@ public class GradientCheckTests extends BaseDL4JTest {
public void testEmbeddingSequenceLayer(){
Nd4j.getRandom().setSeed(12345);
for(RNNFormat seqOutputFormat : RNNFormat.values()) {
for (boolean maskArray : new boolean[]{false, true}) {
for (int inputRank : new int[]{2, 3}) {
@ -572,16 +574,20 @@ public class GradientCheckTests extends BaseDL4JTest {
.layer(new EmbeddingSequenceLayer.Builder()
.nIn(8)
.nOut(4)
.outputDataFormat(seqOutputFormat)
.build())
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
.dataFormat(seqOutputFormat)
.lossFunction(LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
boolean ncw = seqOutputFormat == RNNFormat.NCW;
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
INDArray label = Nd4j.rand(new int[]{3, 3, 6});
INDArray label = Nd4j.rand(DataType.FLOAT, ncw ? new int[]{3, 3, 6} : new int[]{3,6,3});
if (inputRank == 3) {
//Reshape from [3,6] to [3,1,6]
@ -633,6 +639,7 @@ public class GradientCheckTests extends BaseDL4JTest {
}
}
}
}
@Test

View File

@ -21,9 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
@ -341,6 +339,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
@Test
public void testCnnDepthMerge() {
for(CNN2DFormat format : CNN2DFormat.values()) {
String msg = "testCnnDepthMerge - " + format;
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
@ -358,20 +360,20 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
.build(),
"merge")
.setOutputs("outputLayer")
.inputPreProcessor("outputLayer", new CnnToFeedForwardPreProcessor(5, 5, 4))
.setInputTypes(InputType.convolutional(6, 6, 2, format))
.build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(new int[] {5, 2, 6, 6}); //Order: examples, channels, height, width
INDArray input = Nd4j.rand(DataType.DOUBLE, format == CNN2DFormat.NCHW ? new long[]{5,2,6,6} : new long[]{5,6,6,2});
INDArray labels = Nd4j.zeros(5, 3);
for (int i = 0; i < 5; i++)
labels.putScalar(new int[]{i, r.nextInt(3)}, 1.0);
if (PRINT_RESULTS) {
System.out.println("testCnnDepthMerge()");
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
@ -379,14 +381,18 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testCnnDepthMerge()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
}
@Test
public void testRNNWithMerging() {
for(RNNFormat format : RNNFormat.values()) {
String msg = "testLSTMWithMerging - " + format;
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder().seed(12345)
@ -416,19 +422,18 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
"merge")
.inputPreProcessor("dense1", new RnnToFeedForwardPreProcessor())
.inputPreProcessor("lstm3", new FeedForwardToRnnPreProcessor())
.setInputTypes(InputType.recurrent(4, format))
.build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(new int[] {2, 3, 4});
INDArray labels = TestUtils.randomOneHotTimeSeries(2, 3, 4);
INDArray input = Nd4j.rand(DataType.DOUBLE, format == RNNFormat.NCW ? new long[]{2, 3, 4} : new long[]{2,4,3});
INDArray labels = TestUtils.randomOneHotTimeSeries(format, 2, 3, 4, new Random(12345));
if (PRINT_RESULTS) {
System.out.println("testLSTMWithMerging()");
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
@ -436,9 +441,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testLSTMWithMerging()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
}
@Test

View File

@ -17,7 +17,9 @@
package org.deeplearning4j.nn.dtypes;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.preprocessor.*;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
@ -51,16 +53,11 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.util.IdentityLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
@ -97,7 +94,8 @@ public class DTypeTests extends BaseDL4JTest {
Pooling2D.class, //Alias for SubsamplingLayer
Convolution2D.class, //Alias for ConvolutionLayer
Pooling1D.class, //Alias for Subsampling1D
Convolution1D.class //Alias for Convolution1DLayer
Convolution1D.class, //Alias for Convolution1DLayer
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
));
@Override
@ -1078,7 +1076,7 @@ public class DTypeTests extends BaseDL4JTest {
.addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in")
.addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l")
.addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc")
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16})), "preproc2")
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16}, false)), "preproc2")
.addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3")
.setInputTypes(InputType.feedForward(5))
.setOutputs("out");
@ -1150,7 +1148,7 @@ public class DTypeTests extends BaseDL4JTest {
case 7:
b.addInputs("in")
.addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
.addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5)), "1")
.addVertex("2", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 5)), "1")
.addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
.setOutputs("out")
.setInputTypes(InputType.convolutional(28, 28, 1));

View File

@ -60,7 +60,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test
public void testMergeNode() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4);
INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100);
@ -82,7 +82,7 @@ public class TestGraphNodes extends BaseDL4JTest {
public void testMergeNodeRNN() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5);
INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100);
@ -103,7 +103,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test
public void testCnnDepthMerge() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2);
INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10);

View File

@ -15,14 +15,13 @@
******************************************************************************/
package org.deeplearning4j.nn.layers.convolution;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.*;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
@ -30,8 +29,12 @@ import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@ -516,6 +519,40 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
@Test
public void testGlobalPooling() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new ConvolutionLayer.Builder()
@ -735,11 +772,28 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
.setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));
if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
//Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
//DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
}
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
@ -799,8 +853,13 @@ public class ConvDataFormatTests extends BaseDL4JTest {
INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1);
assertEquals(tc.msg, l0_1, l0_2);
if(l0_1.rank() == 4) {
assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
} else {
assertEquals(tc.msg, l0_1, l0_3);
assertEquals(tc.msg, l0_1, l0_4);
}
INDArray out1 = tc.net1.output(inNCHW);
@ -880,4 +939,36 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
return differs;
}
//Converts NHWC to NCHW activations
@EqualsAndHashCode
private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
}
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
}
@Override
public InputPreProcessor clone() {
return this;
}
@Override
public InputType getOutputType(InputType inputType) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
}
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
return null;
}
}
}

View File

@ -44,6 +44,7 @@ import org.nd4j.linalg.primitives.Pair;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
@ -217,7 +218,7 @@ public class TestRnnLayers extends BaseDL4JTest {
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
.list()
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
switch (i){
case 0:
@ -235,10 +236,7 @@ public class TestRnnLayers extends BaseDL4JTest {
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
if (rnnDataFormat == RNNFormat.NWC){
l = l.permute(0, 2, 1);
}
INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new Random(12345));
try{
net.fit(in,l);
} catch (Throwable t){

View File

@ -61,15 +61,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
int tsLength = 7;
INDArray in;
if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
}
else{
in = Nd4j.rand(DataType.FLOAT, new int[]{m, tsLength, nIn});
in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
}
// in.get(all(), all(), interval(1,tsLength)).assign(0);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)

View File

@ -7,10 +7,14 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -106,4 +110,73 @@ public class TestTimeDistributed extends BaseDL4JTest {
}
}
}
@Test
public void testTimeDistributedDense(){
for( int rnnType=0; rnnType<3; rnnType++ ) {
for( int ffType=0; ffType<3; ffType++ ) {
Layer l0, l2;
switch (rnnType) {
case 0:
l0 = new LSTM.Builder().nOut(5).build();
l2 = new LSTM.Builder().nOut(5).build();
break;
case 1:
l0 = new SimpleRnn.Builder().nOut(5).build();
l2 = new SimpleRnn.Builder().nOut(5).build();
break;
case 2:
l0 = new Bidirectional(new LSTM.Builder().nOut(5).build());
l2 = new Bidirectional(new LSTM.Builder().nOut(5).build());
break;
default:
throw new RuntimeException("Not implemented: " + rnnType);
}
Layer l1;
switch (ffType){
case 0:
l1 = new DenseLayer.Builder().nOut(5).build();
break;
case 1:
l1 = new VariationalAutoencoder.Builder().nOut(5).encoderLayerSizes(5).decoderLayerSizes(5).build();
break;
case 2:
l1 = new AutoEncoder.Builder().nOut(5).build();
break;
default:
throw new RuntimeException("Not implemented: " + ffType);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.list()
.layer(l0)
.layer(l1)
.layer(l2)
.setInputType(InputType.recurrent(5, 9, rnnDataFormat))
.build();
BaseRecurrentLayer l0a;
BaseRecurrentLayer l2a;
if (rnnType < 2) {
l0a = (BaseRecurrentLayer) l0;
l2a = (BaseRecurrentLayer) l2;
} else {
l0a = (BaseRecurrentLayer) ((Bidirectional) l0).getFwd();
l2a = (BaseRecurrentLayer) ((Bidirectional) l2).getFwd();
}
assertEquals(rnnDataFormat, l0a.getRnnDataFormat());
assertEquals(rnnDataFormat, l2a.getRnnDataFormat());
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, rnnDataFormat == RNNFormat.NCW ? new long[]{2, 5, 9} : new long[]{2, 9, 5} );
net.output(in);
}
}
}
}

View File

@ -15,21 +15,24 @@
******************************************************************************/
package org.deeplearning4j.convolution;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.*;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.CuDNNTestUtils;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@ -816,6 +819,12 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
.setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));
if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
//Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
//DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
}
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
@ -964,4 +973,35 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
return differs;
}
//Converts NHWC to NCHW activations
@EqualsAndHashCode
private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
}
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
}
@Override
public InputPreProcessor clone() {
return this;
}
@Override
public InputType getOutputType(InputType inputType) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
}
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
return null;
}
}
}

View File

@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest {
ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
model = model.convertDataType(DataType.DOUBLE);
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize});
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC
CuDNNTestUtils.assertHelpersPresent(model.getLayers());
Map<String,INDArray> withCudnn = model.feedForward(in, false);

View File

@ -113,6 +113,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-python</artifactId>
<version>${datavec.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -19,6 +19,9 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -121,27 +124,29 @@ public class KerasInput extends KerasLayer {
InputType myInputType;
switch (this.inputShape.length) {
case 1:
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
break;
case 2:
if(this.dimOrder != null) {
System.out.println("Dim order: " + this.dimOrder);
System.out.println("Input shape: " + ArrayUtils.toString(this.inputShape));
switch (this.dimOrder) {
case TENSORFLOW: //NWC == channels_last
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break;
case THEANO: //NCW == channels_first
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
break;
case NONE:
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break;
default:
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
}
} else {
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
}
break;
@ -150,17 +155,17 @@ public class KerasInput extends KerasLayer {
case TENSORFLOW:
/* TensorFlow convolutional input: # rows, # cols, # channels */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
this.inputShape[2]);
this.inputShape[2], CNN2DFormat.NHWC);
break;
case THEANO:
/* Theano convolutional input: # channels, # rows, # cols */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]);
this.inputShape[0], CNN2DFormat.NCHW);
break;
default:
this.dimOrder = DimOrder.THEANO;
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]);
this.inputShape[0], CNN2DFormat.NCHW);
log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
"versions may come without specified backend, in which case we assume the model was " +
"built with theano." );

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
@ -65,6 +66,9 @@ public class TFOpLayer extends Layer {
long[] shape = inputType.getShape(true);
TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
long[] outputShape = tempLayer.getOutputShape(shape);
if (outputShape.length == 3){
return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC);
}
return InputType.inferInputType(Nd4j.create(outputShape));
}

View File

@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
}
private INDArray runGraph(INDArray input){
if (input.rank() == 3){
// TODO make this a preprocessor
input = input.permute(0, 2, 1);
}
Map<String, INDArray> inputMap = new HashMap<>();
inputMap.put(inputNames.get(0), input);
INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
if (out.rank() == 3){
out = out.permute(0, 2, 1); // TODO post-processing?
}
return out;
}

View File

@ -95,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution {
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
@ -104,7 +103,7 @@ public class KerasConvolution1D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
.hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
.stride(getStrideFromConfig(layerConfig, 1, conf)[0])
.rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
if (hasBias)
builder.biasInit(0.0);

View File

@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
@ -27,6 +28,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.IWeightInit;
import java.util.Map;
@ -93,6 +95,7 @@ public class KerasConvolution2D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
System.out.println("----" + dimOrder);
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
@ -101,7 +104,8 @@ public class KerasConvolution2D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
.hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 2, conf));
.stride(getStrideFromConfig(layerConfig, 2, conf))
.dataFormat(dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion);
if (hasBias)
builder.biasInit(0.0);

View File

@ -360,8 +360,19 @@ public class KerasConvolutionUtils {
}
} else if (dimension == 1) {
Object paddingObj = innerConfig.get(layerField);
if (paddingObj instanceof List){
List<Integer> paddingList = (List)paddingObj;
padding = new int[]{
paddingList.get(0),
paddingList.get(1)
};
}
else {
int paddingInt = (int) paddingObj;
padding = new int[]{paddingInt, paddingInt};
}
} else {
throw new UnsupportedKerasConfigurationException(
"Keras padding layer not supported");

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import java.util.Map;
@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer {
switch (this.getDimOrder()) {
case NONE:
case THEANO:
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels());
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW);
break;
case TENSORFLOW:
preprocessor = new TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(),
it.getChannels());
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC);
break;
default:
throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder());
@ -111,7 +110,7 @@ public class KerasFlatten extends KerasLayer {
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null);
}
return preprocessor;
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer {
super(layerConfig, enforceTrainingConfig);
this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf))
.dataFormat(RNNFormat.NWC)
.name(this.layerName).build();
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW);
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
if (inputShape[0] != targetShape[0])
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC);
}
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
} else {
if (inputShape[0] != targetShape[0])
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
}
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
if (targetShape.length == 3) {
targetShape = targetShapeForDimOrder(inputShape, targetShape);
}
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
}
return preprocessor;
}

View File

@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer {
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization)
.outputDataFormat(RNNFormat.NWC)
.hasBias(false);
if (embeddingConstraint != null)
builder.constrainWeights(embeddingConstraint);

View File

@ -186,7 +186,7 @@ public class KerasLSTM extends KerasLayer {
.weightInitRecurrent(recurrentInit)
.biasInit(0.0) // TODO: this is incorrect
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);

View File

@ -158,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer {
.weightInitRecurrent(recurrentInit)
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);

View File

@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer {
break;
case "SimpleRNN":
kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
this.layer = new Bidirectional(mode, rnnLayer);
layer.setLayerName(layerName);
break;

View File

@ -21,6 +21,9 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
private final long[] inputShape;
private final long[] targetShape;
private boolean hasMiniBatchDimension;
/**
* @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
*/
@Deprecated
public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
this(inputShape, targetShape, false);
}
private DataFormat format;
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
*/
public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) {
this(inputShape, targetShape, hasMiniBatchDimension, null);
}
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
* @param dataFormat May be null. If non-null: the data format (CNN2DFormat or RNNFormat) used when computing the output type
*/
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension,
@JsonProperty("dataFormat") DataFormat dataFormat) {
this.inputShape = inputShape;
this.targetShape = targetShape;
this.hasMiniBatchDimension = hasMiniBatchDimension;
this.format = dataFormat;
}
private long[] getShape(long[] originalShape, long minibatch) {
@ -140,13 +148,26 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
ret = InputType.feedForward(shape[1]);
break;
case 3:
ret = InputType.recurrent(shape[2], shape[1]);
RNNFormat format = RNNFormat.NCW;
if(this.format != null && this.format instanceof RNNFormat)
format = (RNNFormat)this.format;
ret = InputType.recurrent(shape[2], shape[1], format);
break;
case 4:
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
} else {
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
CNN2DFormat cnnFormat = CNN2DFormat.NCHW;
if (this.format != null && this.format instanceof CNN2DFormat)
cnnFormat = (CNN2DFormat) this.format;
if (cnnFormat == CNN2DFormat.NCHW) {
ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
} else {
ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
}
}
break;
default:
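
A construction sketch for the extended ReshapePreprocessor signature (shapes illustrative): reshaping flattened activations back to an image, with the format recorded so getOutputType reports NHWC:

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;

// [mb, 784] -> [mb, 28, 28, 1], reported downstream as NHWC convolutional activations
InputPreProcessor proc = new ReshapePreprocessor(
        new long[]{784}, new long[]{28, 28, 1}, false, CNN2DFormat.NHWC);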

View File

@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/**
* Specialized CnnToFeedForwardInputPreProcessor for use with
* Convolutional layers imported from Keras using the TensorFlow
* backend.
*
* @author dave@skymind.io
* @deprecated Exists only for backward compatibility of older pretrained models. Should not be used.
* Use {@link CnnToFeedForwardPreProcessor} for all new models instead.
*/
@Slf4j
@Slf4j @Deprecated
public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor {
@JsonCreator
@JsonCreator @Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight,
@JsonProperty("inputWidth") long inputWidth,
@JsonProperty("numChannels") long numChannels) {
super(inputHeight, inputWidth, numChannels);
}
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
super(inputHeight, inputWidth);
}
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor() {
super();
}
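
A migration sketch (dimensions illustrative): the deprecated class remains only so older serialized zoo models keep loading; new code should state the data format explicitly on the general preprocessor:

// Before (deprecated): TF-specific subclass with implicit NHWC handling
InputPreProcessor old = new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5);
// After: the general preprocessor with an explicit data format
InputPreProcessor current = new CnnToFeedForwardPreProcessor(28, 28, 5, CNN2DFormat.NHWC);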

View File

@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
@ -319,6 +320,8 @@ public class KerasLayerUtils {
layer = new KerasELU(layerConfig, enforceTrainingConfig);
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
} else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){
layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
} else if (conf instanceof Keras2LayerConfiguration){
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){

View File

@ -1,50 +0,0 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.Arrays;
public class TFKerasTests extends BaseDL4JTest{
@Test
public void testModelWithTFOp1() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
}
@Test
public void testModelWithTFOp2() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
// dl4j's feedforward layers don't support 3D output, so the batch and time axes get squashed
long[] expectedShape = new long[]{12 * 2, 5};
Assert.assertArrayEquals(expectedShape, out.shape());
}
}

View File

@ -0,0 +1,147 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.apache.commons.io.FileUtils;
import org.datavec.python.keras.Model;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.common.tests.ResourceUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.List;
@RunWith(Parameterized.class)
public class TestTFKerasModelImport extends BaseDL4JTest{
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
private String modelFile;
@Override
public long getTimeoutMilliseconds(){
return 300000; //Installing TF will take a while
}
@Parameterized.Parameters(name = "file={0}")
public static Object[] params() throws Exception {
List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false);
return paths.toArray(new String[0]);
}
public TestTFKerasModelImport(String modelFile){
this.modelFile = modelFile;
}
@Test
public void testModelImport() throws Exception{
testModelImportWithData(modelFile);
}
private void testModelImportWithData(String path) throws Exception{
System.out.println(path);
// TODO multi input/output
INDArray inputArray;
INDArray expectedOutputArray;
File f = Resources.asFile(path); //May be inside a JAR, which HDF5 can't read from directly
File modelFile = new File(testDir.getRoot(), f.getName());
FileUtils.copyFile(f, modelFile);
synchronized (Hdf5Archive.LOCK_OBJECT){
Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath());
List<String> rootGroups = hdf5Archive.getGroups();
if (rootGroups.contains("data")){
String inputName = hdf5Archive.readAttributeAsString("input_names", "data");
String outputName = hdf5Archive.readAttributeAsString("output_names", "data");
inputArray = hdf5Archive.readDataSet(inputName, "data");
expectedOutputArray = hdf5Archive.readDataSet(outputName, "data");
}
else{
hdf5Archive.close();
return;
}
hdf5Archive.close();
}
INDArray outputArray;
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
outputArray = dl4jModel.outputSingle(inputArray);
expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT);
outputArray = outputArray.castTo(DataType.FLOAT);
if (path.contains("misc_")){
//shape relaxation
expectedOutputArray = expectedOutputArray.reshape(-1);
outputArray = outputArray.reshape(-1);
}
System.out.println(outputArray.toString());
System.out.println(expectedOutputArray.toString());
Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape());
Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3));
}
private void testModelImportWithKeras(String path) throws Exception{
Model kerasModel = new Model(path);
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays());
Assert.assertEquals(kerasModel.numOutputs(), dl4jModel.getNumOutputArrays());
INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()];
INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()];
for (int i = 0; i < kerasInputArrays.length; i ++) {
long[] shape = kerasModel.inputShapeAt(i);
for (int j = 0; j < shape.length; j++) {
if (shape[j] < 0) {
shape[j] = 1;
}
}
kerasInputArrays[i] = Nd4j.rand(shape);
dl4jInputArrays[i] = kerasInputArrays[i]; //Feed both models the same inputs
}
INDArray[] kerasOut = kerasModel.predict(kerasInputArrays);
INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays);
Assert.assertEquals(kerasOut.length, dl4jOut.length);
for (int i = 0; i < kerasOut.length; i++){
INDArray kerasOutArr = kerasOut[i];
kerasOutArr = kerasOutArr.reshape(1, -1); //A bit of relaxation on shape
kerasOutArr = kerasOutArr.castTo(DataType.DOUBLE);
Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST);
INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1);
System.out.println(kerasOutArr.shapeInfoToString());
System.out.println(dl4jOutArr.shapeInfoToString());
Assert.assertEquals(kerasOutArr, dl4jOutArr);
}
}
}

View File

@ -22,7 +22,6 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@ -34,8 +33,7 @@ public class JsonTest extends BaseDL4JTest {
InputPreProcessor[] pp = new InputPreProcessor[] {
new KerasFlattenRnnPreprocessor(10, 5),
new PermutePreprocessor(new int[]{0,1,2}),
new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}),
new TensorFlowCnnToFeedForwardPreProcessor()
new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}, true, null)
};
for(InputPreProcessor p : pp ){

View File

@ -29,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceTo
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
@ -250,7 +251,7 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
.enforceTrainingConfig(false).buildSequential().getMultiLayerConfiguration();
MultiLayerNetwork model = new MultiLayerNetwork(config);
model.init();
INDArray input = Nd4j.create(50, 500, 1500);
INDArray input = Nd4j.create(DataType.FLOAT, 50, 1500, 500); //NWC format - [Minibatch, seqLength, channels]
INDArray out = model.output(input);
assertTrue(Arrays.equals(out.shape(), new long[]{50, 64}));
}

View File

@ -87,15 +87,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Rule
public final TemporaryFolder testDir = new TemporaryFolder();
public static final BiFunction<String,INDArray,INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() {
@Override
public INDArray apply(String s, INDArray array) {
if(array.rank() == 3)
return array.permute(0, 2, 1); //NWC to NCW
return array;
}
};
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources
@ -169,28 +160,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importImdbLstmTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
}
@Test
public void importImdbLstmThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
}
@Test
public void importImdbLstmTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
}
@Test
public void importImdbLstmThKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, null);
}
/**
@ -262,7 +253,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
"simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
}
/**
@ -316,7 +307,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Test
public void importAcganDiscriminator() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_discriminator_1_epochs.h5");
INDArray input = Nd4j.create(10, 1, 28, 28);
INDArray input = Nd4j.create(10, 28, 28, 1); //NHWC
INDArray[] output = model.output(input);
}
@ -403,7 +394,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
// Make predictions
int miniBatch = 32;
INDArray input = Nd4j.ones(miniBatch, 4, 10);
INDArray input = Nd4j.ones(miniBatch, 10, 4); //NWC format - with nIn=4, seqLength = 10
INDArray[] out = graph.output(input);
// Fit model
@ -450,7 +441,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Test
public void importMobileNet() throws Exception {
ComputationGraph graph = importFunctionalModelH5Test("modelimport/keras/examples/mobilenet/alternative.hdf5");
INDArray input = Nd4j.ones(10, 3, 299, 299);
INDArray input = Nd4j.ones(10, 299, 299, 3);
graph.output(input);
}
@ -462,7 +453,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
int[] inputShape = new int[]{299, 299, 3};
ComputationGraph graph = importFunctionalModelH5Test(
"modelimport/keras/examples/inception/inception_tf_keras_2.h5", inputShape, false);
INDArray input = Nd4j.ones(10, 3, 299, 299);
INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC
graph.output(input);
System.out.println(graph.summary());
}
@ -476,7 +467,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importInception() throws Exception {
ComputationGraph graph = importFunctionalModelH5Test(
"modelimport/keras/examples/inception/inception_v3_complete.h5");
INDArray input = Nd4j.ones(10, 3, 299, 299);
INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC
graph.output(input);
System.out.println(graph.summary());
}
@ -533,14 +524,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
* - Separate (policy and value) residual architecture
* - Separate (policy and value) convolutional architecture
*/
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importSepConvPolicy() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_policy.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
model.output(input);
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importSepResPolicy() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_policy.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
@ -548,28 +539,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importSepConvValue() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_value.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
model.output(input);
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importSepResValue() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_value.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
model.output(input);
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importDualRes() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_res.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
model.output(input);
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importDualConv() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_conv.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
@ -634,16 +625,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
System.out.println("Starting test: " + name);
String modelPath = "modelimport/keras/examples/causal_conv1d/" + name;
String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
Function<INDArray,INDArray> f = new Function<INDArray, INDArray>() {
@Override
public INDArray apply(INDArray i) {
//NWC to NCW
return i.permute(0, 2, 1);
}
};
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
true, true, false, f, nwc2ncwExpected);
true, true, false, null, null);
Layer l = net.getLayer(0);
Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig();
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
@ -707,25 +691,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
System.out.println("Starting test: " + name);
String modelPath = "modelimport/keras/examples/conv1d/" + name;
String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
Function<INDArray,INDArray> f = name.contains("_cf_") ? null : new Function<INDArray, INDArray>() {
@Override
public INDArray apply(INDArray i) {
//NWC to NCW
return i.permute(0, 2, 1);
}
};
BiFunction<String,INDArray,INDArray> f2 = name.contains("_cf_") ? null : new BiFunction<String, INDArray, INDArray>() {
@Override
public INDArray apply(String s, INDArray array) {
// if("conv".equals(s)){
return array.permute(0, 2, 1);
// }
}
};
importEndModelTest(modelPath, inputsOutputPath, true, true,
true, true, false, f, f2);
true, true, false, null, null);
}
}
@ -882,8 +850,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
INDArray[] inputs = new INDArray[inputNames.size()];
for (int i = 0; i < inputNames.size(); i++) {
inputs[i] = archive.readDataSet(inputNames.get(i), GROUP_ATTR_INPUTS);
if (inputs[i].shape().length == 4 && tensorFlowImageDimOrdering)
inputs[i] = inputs[i].permute(0, 3, 1, 2);
}
return inputs;
}
@ -893,8 +859,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
Map<String, INDArray> activations = new HashMap<String, INDArray>();
for (String layerName : archive.getDataSets(GROUP_ACTIVATIONS)) {
INDArray activation = archive.readDataSet(layerName, GROUP_ACTIVATIONS);
if (activation.shape().length == 4 && tensorFlowImageDimOrdering)
activation = activation.permute(0, 3, 1, 2);
activations.put(layerName, activation);
}
return activations;
@ -907,8 +871,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
INDArray[] outputs = new INDArray[outputNames.size()];
for (int i = 0; i < outputNames.size(); i++) {
outputs[i] = archive.readDataSet(outputNames.get(i), GROUP_ATTR_OUTPUTS);
if (outputs[i].shape().length == 4 && tensorFlowImageDimOrdering)
outputs[i] = outputs[i].permute(0, 3, 1, 2);
}
return outputs;
}
@ -920,8 +882,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
INDArray[] predictions = new INDArray[outputNames.size()];
for (int i = 0; i < outputNames.size(); i++) {
predictions[i] = archive.readDataSet(outputNames.get(i), GROUP_PREDICTIONS);
if (predictions[i].shape().length == 4 && tensorFlowImageDimOrdering)
predictions[i] = predictions[i].permute(0, 3, 1, 2);
}
return predictions;
}
@ -941,6 +901,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
// skip too small absolute inputs
if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) {
boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps);
if(!eq){
System.out.println("Expected: " + Arrays.toString(expected.shape()) + ", actual: " + Arrays.toString(actual.shape()));
System.out.println("Expected:\n" + expected);
System.out.println("Actual: \n" + actual);
}
assertTrue("Output differs: " + label, eq);
}
}

View File

@ -176,10 +176,10 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
INDArray bias = model.getLayer(0).getParam("b");
assertEquals(6, bias.length());
INDArray input = Nd4j.ones(1, 5, 3, 4);
INDArray input = Nd4j.ones(1, 3, 4, 5); //NHWC
INDArray output = model.output(input);
assertArrayEquals(new long[] {1, 6, 1, 2}, output.shape());
assertArrayEquals(new long[] {1, 1, 2, 6}, output.shape()); //NHWC
logSuccess(modelPath);
}
@ -224,7 +224,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
INDArray input = Nd4j.zeros(mb, inputLength);
INDArray output = model.output(input);
assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape());
assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC
logSuccess(modelPath);
}
@ -238,9 +238,9 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);
MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false);
INDArray input = Nd4j.zeros(10, 4, 6, 6);
INDArray input = Nd4j.zeros(10, 6, 6, 4);
INDArray output = model.output(input);
assertArrayEquals(new long[]{10, 16, 3, 3}, output.shape());
assertArrayEquals(new long[]{10, 3, 3, 16}, output.shape());
logSuccess(modelPath);
}
@ -248,10 +248,11 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);
ComputationGraph model = loadComputationalGraph(modelPath, false);
INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)};
INDArray input[] = new INDArray[]{Nd4j.zeros(10, 6, 6, 4), Nd4j.zeros(10, 3, 3, 16)};
INDArray[] output = model.output(input);
log.info(Arrays.toString(output[0].shape()));
assertArrayEquals(new long[]{10, 32, 3, 3}, output[0].shape());
assertArrayEquals(new long[]{10, 3, 3, 32}, output[0].shape());
logSuccess(modelPath);
}
@ -278,7 +279,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
INDArray inEmbedding = Nd4j.zeros(mb, inputLength);
INDArray output = model.output(inEmbedding);
assertArrayEquals(new long[]{mb, nOut, inputLength}, output.shape());
assertArrayEquals(new long[]{mb, inputLength, nOut}, output.shape()); //NWC format
logSuccess(modelPath);
}
@ -304,7 +305,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
INDArray inEmbedding = Nd4j.zeros(mb, inputLength);
INDArray output = model.output(inEmbedding);
assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape());
assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC
logSuccess(modelPath);
}

View File

@ -9,7 +9,7 @@ package org.deeplearning4j.nn.conf;
*
* @author Alex Black
*/
public enum CNN2DFormat {
public enum CNN2DFormat implements DataFormat {
NCHW,
NHWC;

View File

@ -0,0 +1,26 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf;
import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
@JsonSerialize(using = DataFormatSerializer.class)
@JsonDeserialize(using = DataFormatDeserializer.class)
public interface DataFormat {
}
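Because the serializer and deserializer are attached at the interface level, any configuration field declared as DataFormat round-trips through JSON as a plain string ("NCHW", "NHWC", "NCW" or "NWC"). A sketch of the round trip; the Wrapper POJO is hypothetical and exists only to exercise the serde:
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.shade.jackson.databind.ObjectMapper;
public class DataFormatSerdeSketch {
    public static class Wrapper {
        public DataFormat format; //Declared as the interface, so the custom serde applies
    }
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        Wrapper w = new Wrapper();
        w.format = RNNFormat.NWC;
        String json = mapper.writeValueAsString(w);         //{"format":"NWC"}
        Wrapper w2 = mapper.readValue(json, Wrapper.class);
        System.out.println(w2.format);                      //NWC
    }
}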

View File

@ -663,7 +663,7 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer;
val nIn = brl.getNIn();
if (nIn > 0) {
inputType = InputType.recurrent(nIn);
inputType = InputType.recurrent(nIn, brl.getRnnDataFormat());
}
} else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer
|| firstLayer instanceof OutputLayer) {

View File

@ -23,7 +23,7 @@ package org.deeplearning4j.nn.conf;
* "width" corresponds to sequence length and "channels" corresponds to sequence item size.
*/
public enum RNNFormat {
public enum RNNFormat implements DataFormat {
NCW,
NWC
}
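The two layouts hold the same data and differ only by a permute of the last two axes. An illustrative sketch (minibatch 32, 4 channels, 10 time steps):
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class RnnFormatSketch {
    public static void main(String[] args) {
        INDArray ncw = Nd4j.rand(DataType.FLOAT, 32, 4, 10); //[minibatch, channels, width(=seqLength)]
        INDArray nwc = ncw.permute(0, 2, 1);                 //[minibatch, width, channels] view of the same data
        System.out.println(java.util.Arrays.toString(nwc.shape())); //[32, 10, 4]
    }
}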

View File

@ -18,6 +18,8 @@ package org.deeplearning4j.nn.conf.graph;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
@ -38,6 +40,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
*/
public class MergeVertex extends GraphVertex {
protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format
@Override
public MergeVertex clone() {
return new MergeVertex();
@ -76,7 +80,7 @@ public class MergeVertex extends GraphVertex {
@Override
public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype);
return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype, mergeAxis);
}
@Override
@ -126,6 +130,7 @@ public class MergeVertex extends GraphVertex {
//FF or RNN data inputs
int size = 0;
InputType.Type type = null;
RNNFormat format = null;
for (int i = 0; i < vertexInputs.length; i++) {
if (vertexInputs[i].getType() != first.getType()) {
throw new InvalidInputTypeException(
@ -142,6 +147,8 @@ public class MergeVertex extends GraphVertex {
break;
case RNN:
thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize();
format = ((InputType.InputTypeRecurrent) vertexInputs[i]).getFormat();
this.mergeAxis = format == RNNFormat.NCW ? 1 : 2;
type = InputType.Type.RNN;
break;
default:
@ -160,7 +167,7 @@ public class MergeVertex extends GraphVertex {
return InputType.feedForward(size);
} else {
val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength();
return InputType.recurrent(size, tsLength);
return InputType.recurrent(size, tsLength, format);
}
} else {
//size is unknown
@ -168,13 +175,14 @@ public class MergeVertex extends GraphVertex {
return InputType.feedForward(-1);
} else {
val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength();
return InputType.recurrent(-1, tsLength);
return InputType.recurrent(-1, tsLength, format);
}
}
} else {
//CNN inputs... also check that the channels, width and heights match:
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
CNN2DFormat format = firstConv.getFormat();
val fd = firstConv.getChannels();
val fw = firstConv.getWidth();
@ -206,7 +214,8 @@ public class MergeVertex extends GraphVertex {
depthSum += od;
}
return InputType.convolutional(fh, fw, depthSum);
this.mergeAxis = format == CNN2DFormat.NCHW ? 1 : 3;
return InputType.convolutional(fh, fw, depthSum, format);
}
}
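The effect of the axis selection above: channels-last inputs (NWC here) are concatenated along the last axis rather than axis 1. A sketch with illustrative shapes:
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class MergeAxisSketch {
    public static void main(String[] args) {
        INDArray a = Nd4j.zeros(8, 10, 3);      //NWC activations, 3 channels
        INDArray b = Nd4j.zeros(8, 10, 5);      //NWC activations, 5 channels
        INDArray merged = Nd4j.concat(2, a, b); //axis 2 for NWC -> shape [8, 10, 8]
        System.out.println(java.util.Arrays.toString(merged.shape()));
    }
}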

View File

@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
@ -91,7 +92,11 @@ public abstract class InputType implements Serializable {
* @return InputTypeFeedForward
*/
public static InputType feedForward(long size) {
return new InputTypeFeedForward(size);
return new InputTypeFeedForward(size, null);
}
public static InputType feedForward(long size, DataFormat timeDistributedFormat) {
return new InputTypeFeedForward(size, timeDistributedFormat);
}
/**
@ -132,7 +137,6 @@ public abstract class InputType implements Serializable {
* @return InputTypeConvolutional
*/
public static InputType convolutional(long height, long width, long depth) {
// return new InputTypeConvolutional(height, width, depth);
return convolutional(height, width, depth, CNN2DFormat.NCHW);
}
@ -191,9 +195,11 @@ public abstract class InputType implements Serializable {
@EqualsAndHashCode(callSuper = false)
public static class InputTypeFeedForward extends InputType {
private long size;
private DataFormat timeDistributedFormat;
public InputTypeFeedForward(@JsonProperty("size") long size) {
public InputTypeFeedForward(@JsonProperty("size") long size, @JsonProperty("timeDistributedFormat") DataFormat timeDistributedFormat) {
this.size = size;
this.timeDistributedFormat = timeDistributedFormat;
}
@Override
@ -203,7 +209,7 @@ public abstract class InputType implements Serializable {
@Override
public String toString() {
return "InputTypeFeedForward(" + size + ")";
return "InputTypeFeedForward(" + size + (timeDistributedFormat != null ? "," + timeDistributedFormat : "") + ")";
}
@Override
@ -302,6 +308,7 @@ public abstract class InputType implements Serializable {
this.height = height;
this.width = width;
this.channels = channels;
if(format != null)
this.format = format;
}

View File

@ -64,11 +64,11 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
+ "\"): expect RNN input type with size > 0. Got: " + inputType);
}
if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
if (nIn <= 0 || override) {
this.nIn = r.getSize();
this.rnnDataFormat = r.getFormat();
}
this.rnnDataFormat = r.getFormat();
}
@Override

View File

@ -44,6 +44,7 @@ import java.util.Map;
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class Convolution1DLayer extends ConvolutionLayer {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
/*
//TODO: We will eventually want to NOT subclass off of ConvolutionLayer.
//Currently, we just subclass off the ConvolutionLayer and hard code the "width" dimension to 1
@ -56,6 +57,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
private Convolution1DLayer(Builder builder) {
super(builder);
initializeConstraints(builder);
this.rnnDataFormat = builder.rnnDataFormat;
}
@Override
@ -92,7 +94,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
convolutionMode, dilation[0]);
}
return InputType.recurrent(nOut, outLength);
return InputType.recurrent(nOut, outLength, rnnDataFormat);
}
@Override
@ -102,10 +105,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ "\"): expect RNN input type with size > 0. Got: " + inputType);
}
if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
if (nIn <= 0 || override) {
this.nIn = r.getSize();
}
this.rnnDataFormat = r.getFormat();
}
@Override
@ -115,11 +119,13 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ "\"): input is null");
}
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat,getLayerName());
}
public static class Builder extends ConvolutionLayer.BaseConvBuilder<Builder> {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public Builder() {
this(0, 1, 0);
this.setKernelSize((int[]) null);
@ -130,6 +136,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
return true;
}
public Builder rnnDataFormat(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
return this;
}
/**
* @param kernelSize Kernel size
* @param stride Stride
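A usage sketch for the new builder option; the kernel size, nIn and nOut values are illustrative, and the single-argument kernel-size constructor is assumed from the existing builder overloads:
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
public class Conv1DFormatSketch {
    public static void main(String[] args) {
        Convolution1DLayer conv = new Convolution1DLayer.Builder(3) //kernel size 3
                .nIn(4).nOut(16)
                .rnnDataFormat(RNNFormat.NWC) //consume and produce [mb, seqLength, channels]
                .build();
        System.out.println(conv);
    }
}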

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
@ -58,12 +59,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
private int inputLength = 1; // By default only use one index to embed
private boolean hasBias = false;
private boolean inferInputLength = false; // use input length as provided by input data
private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models
private EmbeddingSequenceLayer(Builder builder) {
super(builder);
this.hasBias = builder.hasBias;
this.inputLength = builder.inputLength;
this.inferInputLength = builder.inferInputLength;
this.outputFormat = builder.outputFormat;
initializeConstraints(builder);
}
@ -87,7 +90,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
throw new IllegalStateException("Invalid input for Embedding layer (layer index = " + layerIndex
+ ", layer name = \"" + getLayerName() + "\"): expect FF/RNN input type. Got: " + inputType);
}
return InputType.recurrent(nOut, inputLength);
return InputType.recurrent(nOut, inputLength, outputFormat);
}
@Override
@ -167,6 +170,13 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
*/
private boolean inferInputLength = true;
private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models
public Builder outputDataFormat(RNNFormat format){
this.outputFormat = format;
return this;
}
/**
* If true: include bias parameters in the layer. False (default): no bias.
*
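A usage sketch for the configurable output format, assuming the existing builder setters (vocabulary and embedding sizes are illustrative):
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
public class EmbeddingNwcSketch {
    public static void main(String[] args) {
        EmbeddingSequenceLayer layer = new EmbeddingSequenceLayer.Builder()
                .nIn(1000)       //vocabulary size
                .nOut(32)        //embedding dimension
                .inputLength(10)
                .outputDataFormat(RNNFormat.NWC) //activations [mb, 10, 32] instead of NCW's [mb, 32, 10]
                .build();
        System.out.println(layer);
    }
}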

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
@ -35,6 +36,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
protected long nIn;
protected long nOut;
protected DataFormat timeDistributedFormat;
public FeedForwardLayer(Builder builder) {
super(builder);
@ -51,7 +53,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
+ getLayerName() + "\"): expected FeedForward input type. Got: " + inputType);
}
return InputType.feedForward(nOut);
return InputType.feedForward(nOut, timeDistributedFormat);
}
@Override
@ -71,6 +73,11 @@ public abstract class FeedForwardLayer extends BaseLayer {
this.nIn = f.getFlattenedSize();
}
}
if(inputType instanceof InputType.InputTypeFeedForward){
InputType.InputTypeFeedForward f = (InputType.InputTypeFeedForward) inputType;
this.timeDistributedFormat = f.getTimeDistributedFormat();
}
}
@Override

View File

@ -536,11 +536,17 @@ public class InputTypeUtil {
}
switch (inputType.getType()) {
case FF:
case CNNFlat:
//FF -> RNN or CNNFlat -> RNN
//In either case, input data format is a row vector per example
return new FeedForwardToRnnPreProcessor(rnnDataFormat);
case FF:
//If time distributed format is defined, use that. Otherwise use the layer-defined rnnDataFormat, which may be default
InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward)inputType;
if(ff.getTimeDistributedFormat() != null && ff.getTimeDistributedFormat() instanceof RNNFormat){
return new FeedForwardToRnnPreProcessor((RNNFormat) ff.getTimeDistributedFormat());
}
return new FeedForwardToRnnPreProcessor(rnnDataFormat);
case RNN:
//RNN -> RNN: No preprocessor necessary
return null;
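A sketch of the dispatch above: when the feed-forward input type carries a time-distributed format (as set, for example, by RnnToFeedForwardPreProcessor elsewhere in this change), that format takes precedence over the layer-supplied default:
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
public class TimeDistributedFormatSketch {
    public static void main(String[] args) {
        InputType ff = InputType.feedForward(64, RNNFormat.NWC); //FF type that remembers its NWC origin
        InputPreProcessor p = InputTypeUtil.getPreprocessorForInputTypeRnnLayers(ff, RNNFormat.NCW, "myLayer");
        System.out.println(p); //FeedForwardToRnnPreProcessor configured for NWC, not the NCW default
    }
}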

View File

@ -98,9 +98,9 @@ public class RnnOutputLayer extends BaseOutputLayer {
+ "\"): Expected RNN input, got " + inputType);
}
if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.rnnDataFormat = r.getFormat();
if (nIn <= 0 || override) {
this.nIn = r.getSize();
}
}

View File

@ -91,7 +91,7 @@ public class Subsampling1DLayer extends SubsamplingLayer {
outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
convolutionMode, dilation[0]);
}
return InputType.recurrent(r.getSize(), outLength);
return InputType.recurrent(r.getSize(), outLength, r.getFormat());
}
@Override

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers.misc;
import lombok.*;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
@ -46,10 +47,12 @@ import java.util.Map;
public class RepeatVector extends FeedForwardLayer {
private int n = 1;
private RNNFormat dataFormat = RNNFormat.NCW;
protected RepeatVector(Builder builder) {
super(builder);
this.n = builder.n;
this.dataFormat = builder.dataFormat;
}
@Override
@ -83,7 +86,7 @@ public class RepeatVector extends FeedForwardLayer {
+ "\"): Expected FF input, got " + inputType);
}
InputType.InputTypeFeedForward ffInput = (InputType.InputTypeFeedForward) inputType;
return InputType.recurrent(ffInput.getSize(), n);
return InputType.recurrent(ffInput.getSize(), n, this.dataFormat);
}
@Override
@ -101,13 +104,14 @@ public class RepeatVector extends FeedForwardLayer {
}
@NoArgsConstructor
@Getter
@Setter
public static class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {
private int n = 1; // no repetition by default
private RNNFormat dataFormat = RNNFormat.NCW;
/**
* Set repetition factor for RepeatVector layer
*/
@ -115,6 +119,15 @@ public class RepeatVector extends FeedForwardLayer {
return n;
}
public RNNFormat getDataFormat(){
return dataFormat;
}
public Builder dataFormat(RNNFormat dataFormat){
this.dataFormat = dataFormat;
return this;
}
/**
* Set repetition factor for RepeatVector layer
*
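A usage sketch, assuming the builder's existing repetitionFactor setter (n = 3 is illustrative):
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
public class RepeatVectorFormatSketch {
    public static void main(String[] args) {
        RepeatVector rv = new RepeatVector.Builder()
                .repetitionFactor(3)
                .dataFormat(RNNFormat.NWC) //output [mb, 3, size]; NCW gives [mb, size, 3]
                .build();
        System.out.println(rv);
    }
}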

View File

@ -39,11 +39,13 @@ import java.util.Arrays;
* For example, CNN -> Denselayer <br>
* This does two things:<br>
* (a) Reshapes 4d activations out of CNN layer, with shape
* [numExamples, numChannels, inputHeight, inputWidth] (for {@link CNN2DFormat#NCHW} format activations) or shape
* [numExamples, inputHeight, inputWidth, numChannels] (for {@link CNN2DFormat#NHWC} format activations) into 2d activations
* (with shape [numExamples, inputHeight*inputWidth*numChannels]) for use in a feed forward layer.<br>
* (b) Reshapes epsilons (weights*deltas) out of FeedForward layer (which is 2D or 3D with shape
* [numExamples, inputHeight*inputWidth*numChannels]) into 4d epsilons (with shape
* [numExamples, numChannels, inputHeight, inputWidth] or [numExamples, inputHeight, inputWidth, numChannels]) suitable to
* feed into CNN layers.<br>
* Note: numChannels is equivalent to channels or featureMaps referenced in different literature
* @author Adam Gibson
* @see FeedForwardToCnnPreProcessor for the opposite case (i.e., DenseLayer -> CNN etc)
@ -68,6 +70,7 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = numChannels;
if(format != null)
this.format = format;
}
@ -96,10 +99,17 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
wDim = 2;
}
if(inputHeight == 0 && inputWidth == 0 && numChannels == 0){
this.inputHeight = input.size(hDim);
this.inputWidth = input.size(wDim);
this.numChannels = input.size(chDim);
}
if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || input.size(wDim) != inputWidth){
throw new IllegalStateException("Invalid input, does not match configuration: expected [minibatch, numChannels="
+ numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] but got input array of" +
"shape " + Arrays.toString(input.shape()));
throw new IllegalStateException("Invalid input, does not match configuration: expected " +
(format == CNN2DFormat.NCHW ? "[minibatch, numChannels=" + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] " :
"[minibatch, inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + ", numChannels=" + numChannels + "]") +
" but got input array of shape " + Arrays.toString(input.shape()));
}
//Check input: nchw format
@ -110,15 +120,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
+ Arrays.toString(input.shape()));
}
if(format == CNN2DFormat.NHWC) {
input = input.permute(0, 3, 1, 2); //NHWC to NCHW
}
//Assume input is standard rank 4 activations out of CNN layer
//First: we require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
//Note that to match Tensorflow/Keras, we do a simple "c order reshape" for both NCHW and NHWC
val inShape = input.shape(); //[miniBatch,depthOut,outH,outW]
val outShape = new long[]{inShape[0], inShape[1] * inShape[2] * inShape[3]};
@ -139,11 +147,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
+ inputHeight + " x columns " + inputWidth + " x channels " + numChannels + " but was instead "
+ Arrays.toString(epsilons.shape()));
INDArray ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
if(format == CNN2DFormat.NHWC){
ret = ret.permute(0,2,3,1); //NCHW to NHWC
INDArray ret;
if(format == CNN2DFormat.NCHW){
ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
} else {
ret = epsilons.reshape('c', epsilons.size(0), inputHeight, inputWidth, numChannels);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace
}
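The behavioural change in isolation: flattening is now a plain c-order reshape for both layouts, so NHWC activations flatten in [h, w, c] order, exactly matching Keras/TF Flatten. A sketch with illustrative shapes:
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class FlattenOrderSketch {
    public static void main(String[] args) {
        INDArray nhwc = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 5);    //[mb, h, w, c]
        INDArray flat = nhwc.dup('c').reshape('c', 2, 3 * 4 * 5); //no NHWC -> NCHW permute first
        System.out.println(java.util.Arrays.toString(flat.shape())); //[2, 60]
    }
}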

View File

@ -52,6 +52,7 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
if(rnnDataFormat != null)
this.rnnDataFormat = rnnDataFormat;
}
@Override

View File

@ -57,6 +57,7 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
if(rnnDataFormat != null)
this.rnnDataFormat = rnnDataFormat;
}
@Override
@ -116,7 +117,7 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
}
InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
return InputType.feedForward(rnn.getSize());
return InputType.feedForward(rnn.getSize(), rnn.getFormat());
}
@Override

View File

@ -0,0 +1,52 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf.serde.format;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import java.io.IOException;
/**
* Simple JSON deserializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat}
*
* @author Alex Black
*/
public class DataFormatDeserializer extends JsonDeserializer<DataFormat> {
@Override
public DataFormat deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
JsonNode node = jp.getCodec().readTree(jp);
String text = node.textValue();
switch (text){
case "NCHW":
return CNN2DFormat.NCHW;
case "NHWC":
return CNN2DFormat.NHWC;
case "NCW":
return RNNFormat.NCW;
case "NWC":
return RNNFormat.NWC;
default:
return null;
}
}
}

View File

@ -0,0 +1,37 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf.serde.format;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import java.io.IOException;
/**
* Simple JSON serializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat}
*
* @author Alex Black
*/
public class DataFormatSerializer extends JsonSerializer<DataFormat> {
@Override
public void serialize(DataFormat dataFormat, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
jsonGenerator.writeString(dataFormat.toString());
}
}

View File

@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType;
@ -48,14 +49,16 @@ public class MergeVertex extends BaseGraphVertex {
private long[][] forwardPassShapes;
private int fwdPassRank;
private int mergeAxis;
public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType) {
this(graph, name, vertexIndex, null, null, dataType);
public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType, int mergeAxis) {
this(graph, name, vertexIndex, null, null, dataType, mergeAxis);
}
public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
VertexIndices[] outputVertices, DataType dataType) {
VertexIndices[] outputVertices, DataType dataType, int mergeAxis) {
super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
this.mergeAxis = mergeAxis;
}
@Override
@ -92,7 +95,6 @@ public class MergeVertex extends BaseGraphVertex {
forwardPassShapes = new long[in.length][0];
val nExamples = in[0].size(0);
int nOut = 0;
fwdPassRank = in[0].rank();
for (int i = 0; i < in.length; i++) {
val currShape = in[i].shape();
@ -109,12 +111,11 @@ public class MergeVertex extends BaseGraphVertex {
+ Arrays.toString(in[0].shape()) + ", activations[" + i
+ "] shape: " + Arrays.toString(in[i].shape()));
}
nOut += currShape[1]; //Same dimension for all of CNNs, FF, RNNs
}
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
return Nd4j.concat(1, in);
INDArray out = Nd4j.concat(mergeAxis, in);
return out;
}
}
@ -145,20 +146,16 @@ public class MergeVertex extends BaseGraphVertex {
break;
case 3:
for (int i = 0; i < forwardPassShapes.length; i++) {
out[i].assign(epsilon.get(NDArrayIndex.all(), //All rows
NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //subset of columns
NDArrayIndex.all())); //All time steps
out[i].assign(epsilon.get(indices(3, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //Interval along the merge axis only
cumulative += forwardPassShapes[i][1];
cumulative += forwardPassShapes[i][mergeAxis];
}
break;
case 4:
for (int i = 0; i < forwardPassShapes.length; i++) {
out[i].assign(epsilon.get(NDArrayIndex.all(),
NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //Subset of depth
NDArrayIndex.all(), //Width
NDArrayIndex.all())); //height
cumulative += forwardPassShapes[i][1];
out[i].assign(epsilon.get(indices(4, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //Interval along the merge axis only
cumulative += forwardPassShapes[i][mergeAxis];
}
break;
default:
@ -168,6 +165,19 @@ public class MergeVertex extends BaseGraphVertex {
return new Pair<>(null, out);
}
private INDArrayIndex[] indices(int num, int axis, long from, long to){
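//e.g. indices(3, 2, from, to) -> [all(), all(), interval(from, to)]: an interval along the merge axis, all() on every other axis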
INDArrayIndex[] out = new INDArrayIndex[num];
for( int i=0; i<num; i++ ){
if(i == axis){
out[i] = NDArrayIndex.interval(from, to);
} else {
out[i] = NDArrayIndex.all();
}
}
return out;
}
@Override
public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
if (backpropGradientsViewArray != null)

View File

@ -79,7 +79,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
ILossFunction lossFunction = layerConf().getLossFn();
double score = lossFunction.computeScore(getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut,
INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
double score = lossFunction.computeScore(labels2d, preOut,
layerConf().getActivationFn(), maskArray,false);
if(conf().isMiniBatch())

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D;
@ -71,7 +72,11 @@ public class RepeatVector extends AbstractLayer<org.deeplearning4j.nn.conf.layer
INDArray outEpsilon;
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
if (layerConf().getDataFormat() == RNNFormat.NCW) {
outEpsilon = epsilon.sum(2);
} else {
outEpsilon = epsilon.sum(1);
}
}
Gradient gradient = new DefaultGradient();
@ -99,11 +104,23 @@ public class RepeatVector extends AbstractLayer<org.deeplearning4j.nn.conf.layer
long miniBatch = input.size(0);
long size = input.size(1);
if (getDataFormat() == RNNFormat.NCW) {
INDArray output = input.reshape(miniBatch, size, 1).castTo(dataType);
try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
return output.repeat(2, (long) getN());
}
} else {
INDArray output = input.reshape(miniBatch, 1, size).castTo(dataType);
try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
return output.repeat(1, (long) getN());
}
}
}
public RNNFormat getDataFormat(){
return layerConf().getDataFormat();
}
@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {

View File

@ -20,6 +20,8 @@ import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
@ -74,6 +76,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ Arrays.toString(epsilon.shape())
+ ". Expected rank 3 array with shape [minibatchSize, features, length]. " + layerId());
if (getRnnDataFormat() == RNNFormat.NWC){
epsilon = epsilon.permute(0, 2, 1);
this.input = input.permute(0, 2, 1);
}
if(maskArray != null){
INDArray maskOut = feedForwardMaskArray(maskArray, MaskState.Active, (int)epsilon.size(0)).getFirst();
Preconditions.checkState(epsilon.size(0) == maskOut.size(0) && epsilon.size(2) == maskOut.size(1),
@ -125,6 +131,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY));
}
retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c');
if (getRnnDataFormat() == RNNFormat.NWC){
epsOut = epsOut.permute(0, 2, 1);
this.input = input.permute(0, 2, 1);
}
return new Pair<>(retGradient, epsOut);
}
@ -140,7 +150,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
// remove singleton fourth dimension from input and current epsilon
epsNext = epsNext.reshape(epsNext.size(0), epsNext.size(1), epsNext.size(2));
input = origInput;
if (getRnnDataFormat() == RNNFormat.NWC){
epsNext = epsNext.permute(0, 2, 1);
this.input = input.permute(0, 2, 1);
}
return new Pair<>(gradientEpsNext.getFirst(), epsNext);
}
@ -185,7 +198,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
.s(c.getStride()[0])
.d(c.getDilation()[0])
.p(c.getPadding()[0])
.dataFormat(Conv1DConfig.NCW) //Data has already been permuted to NCW at this point when the layer's rnnDataFormat is NWC
.paddingMode(PaddingMode.CAUSAL)
.build();
INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
@ -209,6 +223,9 @@ public class Convolution1DLayer extends ConvolutionLayer {
@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
if (getRnnDataFormat() == RNNFormat.NWC){
this.input = input.permute(0, 2, 1);
}
INDArray act4d = super.activate(training, workspaceMgr);
INDArray act3d = act4d.reshape(act4d.size(0), act4d.size(1), act4d.size(2));
@ -219,6 +236,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
act3d.shape(), maskOut.shape());
Broadcast.mul(act3d, maskOut, act3d, 0, 2);
}
if (getRnnDataFormat() == RNNFormat.NWC){
this.input = input.permute(0, 2, 1);
act3d = act3d.permute(0, 2, 1);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, act3d); //Should be zero copy most of the time
}
@ -231,4 +252,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
layerConf().getConvolutionMode());
return new Pair<>(reduced, currentMaskState);
}
private RNNFormat getRnnDataFormat(){
return ((org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf()).getRnnDataFormat();
}
}

View File

@ -160,7 +160,8 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr);
INDArray z = p.getFirst();
if(layerConf().getCnn2dDataFormat() != CNN2DFormat.NCHW){
CNN2DFormat f = layerConf().getCnn2dDataFormat();
if(f != CNN2DFormat.NCHW){
z = z.permute(0,3,1,2); //NHWC to NCHW
}
delta = afn.backprop(z, epsilon).getFirst(); //TODO handle activation function params

View File

@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
@ -64,8 +65,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
INDArray z = preOutput(true, workspaceMgr);
INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); //Shape: [mb, vector, seqLength]
boolean ncw = layerConf().getOutputFormat() == RNNFormat.NCW;
if (maskArray != null) {
if(ncw){
delta = Broadcast.mul(delta, maskArray, delta, 0, 2);
} else {
delta = Broadcast.mul(delta, maskArray, delta, 0, 1);
}
}
int inputLength = layerConf().getInputLength();
@ -76,7 +83,10 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
delta = delta.dup('c');
}
if(ncw){
delta = delta.permute(0, 2, 1); //From [minibatch, nOut, length] to [minibatch, length, nOut]
}
delta = delta.reshape('c',inputLength * numSamples, nOut);
INDArray weightGradients = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY);
@ -159,7 +169,10 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
}
val shape = new long[]{minibatch, inputLength, nOut};
INDArray ret = rows.reshape('c', shape).permute(0, 2, 1);
INDArray ret = rows.reshape('c', shape);
if(layerConf().getOutputFormat() == RNNFormat.NCW){
ret = ret.permute(0, 2, 1); //[minibatch, seqLen, nOut] -> [minibatch, nOut, seqLen] i.e., NWC -> NCW
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
}
@ -177,8 +190,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
" 2 (when input is rank 3, shape [mb,1,tsLength]). Input shape: " + Arrays.toString(input.shape()) +
", mask shape: " + Arrays.toString(maskArray.shape()));
}
boolean ncw = layerConf().getOutputFormat() == RNNFormat.NCW;
if(ncw){
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
} else {
//Returned array: rank 3, shape [mb, seqLength, vector]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 1);
}
}
return ret;
}
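
The only thing that changes between the two branches above is the dimension pair handed to Broadcast.mul: it tells ND4J which axes of the rank-3 activations the [mb, seqLength] mask lines up with - (0, 2) when the sequence axis is last (NCW), (0, 1) when it is in the middle (NWC). A standalone sketch of both cases (mask values arbitrary):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;

public class MaskBroadcastSketch {
    public static void main(String[] args) {
        INDArray mask = Nd4j.createFromArray(new float[][]{{1, 1, 0}, {1, 0, 0}}); // [mb=2, seqLen=3]

        INDArray ncw = Nd4j.ones(2, 4, 3);   // [mb, vector, seqLen]: mask matches dims 0 and 2
        Broadcast.mul(ncw, mask, ncw, 0, 2);

        INDArray nwc = Nd4j.ones(2, 3, 4);   // [mb, seqLen, vector]: mask matches dims 0 and 1
        Broadcast.mul(nwc, mask, nwc, 0, 1);

        System.out.println(ncw);             // timestep 3 (example 0) and timesteps 2-3 (example 1) zeroed
        System.out.println(nwc);             // same timesteps zeroed in the NWC layout
    }
}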

View File

@ -20,11 +20,13 @@ import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -152,6 +154,14 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
assertInputSet(false);
Preconditions.checkState(input.rank() == 3,
"3D input expected to RNN layer expected, got " + input.rank());
boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(layerConf()) == RNNFormat.NWC;
INDArray origInput = input;
if(nwc){
input = permuteIfNWC(input);
}
applyDropOutIfNecessary(training, workspaceMgr);
//TODO LSTM cache mode is disabled for now - not passing all tests
@ -166,7 +176,6 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
final INDArray recurrentWeights = getParamWithNoise(LSTMParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights = getParamWithNoise(LSTMParamInitializer.INPUT_WEIGHT_KEY, training, workspaceMgr); //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray biases = getParamWithNoise(LSTMParamInitializer.BIAS_KEY, training, workspaceMgr); //by row: IFOG //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T
INDArray input = permuteIfNWC(this.input);
FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
prevMemCellState, (training && cacheMode != CacheMode.NONE) || forBackprop, true,
@ -178,6 +187,11 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
if (training && cacheMode != CacheMode.NONE) {
cachedFwdPass = fwd;
}
if(nwc){
input = origInput;
}
return fwd;
}
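
Note why the input is saved and restored around the forward pass rather than permuted back: permute returns a view over the same buffer, so keeping the original reference in origInput is a zero-copy way to hand this.input back to later code (backprop, dropout masks) in the layout it expects. A small sketch of the pattern - the two-argument permuteIfNWC here is a hypothetical stand-in; the real helper lives on the recurrent layer base class and reads the format from the layer config:

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SaveRestoreSketch {
    // Hypothetical stand-in for the layer's permuteIfNWC helper
    static INDArray permuteIfNWC(INDArray in, boolean nwc) {
        return nwc ? in.permute(0, 2, 1) : in;   // [mb, seqLen, size] <-> [mb, size, seqLen]
    }

    public static void main(String[] args) {
        boolean nwc = true;
        INDArray input = Nd4j.rand(new int[]{2, 5, 3});     // NWC input
        INDArray origInput = input;                         // save the caller-visible layout
        input = permuteIfNWC(input, nwc);                   // NCW view for the internal math
        // ... forward pass on the NCW view would run here ...
        input = origInput;                                  // zero-copy restore
        System.out.println(Arrays.toString(input.shape())); // [2, 5, 3]
    }
}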

View File

@ -61,11 +61,8 @@ public class LastTimeStepLayer extends BaseWrapperLayer {
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
long[] newEpsShape = origOutputShape;
boolean nwc = (underlying instanceof BaseRecurrentLayer &&
((BaseRecurrentLayer) underlying).getDataFormat() == RNNFormat.NWC)||
(underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof
BaseRecurrentLayer && ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat()
== RNNFormat.NWC);
boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC;
INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f');
if(lastTimeStepIdxs == null){
//no mask case

View File

@ -58,7 +58,8 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
}
int td = (layerConf().getRnnDataFormat()==RNNFormat.NCW)? 2: 1;
RNNFormat format = layerConf().getRnnDataFormat();
int td = (format == RNNFormat.NCW) ? 2 : 1;
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnOutputLayer input and labels: " +
"arrays should be rank 3 with shape [minibatch, size, sequenceLength] (NCW) or [minibatch, sequenceLength, size] (NWC) - mismatch on the sequence length dimension - input=%ndShape vs. label=%ndShape", input, labels);

View File

@ -118,6 +118,7 @@ public class ConvolutionParamInitializer implements ParamInitializer {
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
conf.addVariable(WEIGHT_KEY);
conf.addVariable(BIAS_KEY);
} else {
INDArray weightView = paramsView;
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));

View File

@ -34,6 +34,7 @@ import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -53,7 +54,7 @@ import static org.junit.Assert.*;
/**
* @author Tamas Fenyvesi
*/
@Slf4j
@Slf4j @Ignore //https://github.com/eclipse/deeplearning4j/issues/8891
public class TestVertxUIMultiSession extends BaseDL4JTest {
@Before