tf.keras import test and fixes (#347)

* merge conf

* merge conf

* tfkeras tests

* parameterized tests

* rename

* cuda versions

* jccp versions

* 'updates'

* updates

* rnn+mlp passing

* repeat

* updates

* tests

* Update pom.xml

* Update pom.xml

* rem print

* cnn1d model conversion fixed

* cnn1d activate fixed

* cnn1d output shape fix

* cnn1d bprop fix

* cnn1d stack fix

* KerasModelEndToEndTest - Remove permutes for NWC and NHWC format tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes and update test - input shapes (NCHW -> NHWC input)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Ignore for known bad tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Multiple fixes - MergeVertex, CNN1D layers, etc

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix issue with RNN/FF preprocessors, time distributed etc with NWC format

Signed-off-by: Alex Black <blacka101@gmail.com>

* LSTM NWC dropout fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Add sequence embedding layer NWC support (configurable output format)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix expected shape in a couple of tests - NWC expected

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix EmbeddingSequenceLayer backprop for NWC output case + add gradient checks

Signed-off-by: Alex Black <blacka101@gmail.com>

* CnnToFeedForwardPreprocessor: align with Keras/TF; fix Keras reshape/flatten

Signed-off-by: Alex Black <blacka101@gmail.com>

* Update ConvDataFormatTests to match new reshape behaviour

Signed-off-by: Alex Black <blacka101@gmail.com>

* Switch hard-coded path to ResourceUtils.listClassPathFiles for TestTFKerasModelImport

Signed-off-by: Alex Black <blacka101@gmail.com>

* TestUtils fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix JSON serde issue with data formats

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix for input dtype inference; fix 2 tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8891 Ignore for TestVertxUIMultiSession until fixed

Signed-off-by: Alex Black <blacka101@gmail.com>

* Restore but deprecate TensorFlowCnnToFeedForwardPreProcessor for older zoo models

Signed-off-by: Alex Black <blacka101@gmail.com>

* Ignore for deprecated preprocessor in DTypeTests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Remove debug printlns

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
Fariz Rahman, 2020-04-28 14:31:09 +04:00, committed by GitHub (branch: master)
parent b9d5f1645b
commit 4cb87a94e8
67 changed files with 1336 additions and 438 deletions

View File

@@ -20,6 +20,7 @@ package org.datavec.python;

 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.IOUtils;
+import org.bytedeco.cpython.global.python;
 import org.bytedeco.numpy.global.numpy;
 import org.nd4j.linalg.api.concurrency.AffinityManager;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -343,6 +344,19 @@ public class PythonExecutioner {
         if (path == null) {
             log.info("Setting python default path");
             File[] packages = numpy.cachePackages();
+            // TODO: fix in javacpp - workaround: append the cpython "site-packages"
+            // directory, which numpy.cachePackages() does not include
+            File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
+            File[] packages2 = new File[packages.length + 1];
+            for (int i = 0; i < packages.length; i++) {
+                packages2[i] = packages[i];
+            }
+            packages2[packages.length] = sitePackagesWindows;
+            packages = packages2;
             Py_SetPath(packages);
         } else {
             log.info("Setting python path " + path);

View File

@@ -0,0 +1,132 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/**
 * Runs the bundled CPython executable as an external process, mainly for pip package management.
 */
@Slf4j
public class PythonProcess {

    private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);

    public static String runAndReturn(String... arguments) throws IOException, InterruptedException {
        String[] allArgs = new String[arguments.length + 1];
        for (int i = 0; i < arguments.length; i++) {
            allArgs[i + 1] = arguments[i];
        }
        allArgs[0] = pythonExecutable;
        log.info("Executing command: " + Arrays.toString(allArgs));
        ProcessBuilder pb = new ProcessBuilder(allArgs);
        Process process = pb.start();
        String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
        process.waitFor();
        return out;
    }

    public static void run(String... arguments) throws IOException, InterruptedException {
        String[] allArgs = new String[arguments.length + 1];
        for (int i = 0; i < arguments.length; i++) {
            allArgs[i + 1] = arguments[i];
        }
        allArgs[0] = pythonExecutable;
        log.info("Executing command: " + Arrays.toString(allArgs));
        ProcessBuilder pb = new ProcessBuilder(allArgs);
        pb.inheritIO().start().waitFor();
    }

    public static void pipInstall(String packageName) throws PythonException {
        try {
            run("-m", "pip", "install", packageName);
        } catch (Exception e) {
            throw new PythonException("Error installing package " + packageName, e);
        }
    }

    public static void pipInstall(String packageName, String version) throws PythonException {
        pipInstall(packageName + "==" + version);
    }

    public static void pipUninstall(String packageName) throws PythonException {
        try {
            run("-m", "pip", "uninstall", "-y", packageName);   // -y: don't block on the confirmation prompt
        } catch (Exception e) {
            throw new PythonException("Error uninstalling package " + packageName, e);
        }
    }

    public static void pipInstallFromGit(String gitRepoUrl) throws PythonException {
        if (!gitRepoUrl.contains("://")) {
            gitRepoUrl = "git://" + gitRepoUrl;
        }
        try {
            run("-m", "pip", "install", "git+" + gitRepoUrl);   // pip expects "git+<url>" as a single argument
        } catch (Exception e) {
            throw new PythonException("Error installing package from " + gitRepoUrl, e);
        }
    }

    public static String getPackageVersion(String packageName) throws PythonException {
        String out;
        try {
            out = runAndReturn("-m", "pip", "show", packageName);
        } catch (Exception e) {
            throw new PythonException("Error finding version for package " + packageName, e);
        }

        if (!out.contains("Version: ")) {
            throw new PythonException("Can't find package " + packageName);
        }
        return out.split("Version: ")[1].split(System.lineSeparator())[0];
    }

    public static boolean isPackageInstalled(String packageName) throws PythonException {
        try {
            String out = runAndReturn("-m", "pip", "show", packageName);
            return !out.isEmpty();
        } catch (Exception e) {
            throw new PythonException("Error checking if package is installed: " + packageName, e);
        }
    }

    public static void pipInstallFromRequirementsTxt(String path) throws PythonException {
        try {
            run("-m", "pip", "install", "-r", path);
        } catch (Exception e) {
            throw new PythonException("Error installing packages from " + path, e);
        }
    }

    public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException {
        try {
            run(path, inplace ? "develop" : "install");
        } catch (Exception e) {
            throw new PythonException("Error installing package from " + path, e);
        }
    }
}
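
For context, a minimal usage sketch of the PythonProcess helper above (not part of the diff; the class name and the tensorflow version pin are illustrative):

    import org.datavec.python.PythonProcess;

    public class PythonProcessExample {
        public static void main(String[] args) throws Exception {
            // Install tensorflow via pip if it is not already present, then report its version
            if (!PythonProcess.isPackageInstalled("tensorflow")) {
                PythonProcess.pipInstall("tensorflow", "2.1.0");   // illustrative version pin
            }
            System.out.println("tensorflow " + PythonProcess.getPackageVersion("tensorflow"));
        }
    }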

View File

@@ -0,0 +1,144 @@
package org.datavec.python.keras;

import org.datavec.python.Python;
import org.datavec.python.PythonException;
import org.datavec.python.PythonObject;
import org.datavec.python.PythonProcess;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Wraps a tf.keras model loaded through the embedded CPython interpreter,
 * exposing predict() and input/output shape queries on ND4J types.
 */
public class Model {

    private PythonObject pyModel;

    private static PythonObject installAndImportTF() throws PythonException {
        if (!PythonProcess.isPackageInstalled("tensorflow")) {
            PythonProcess.pipInstall("tensorflow");
        }
        return Python.importModule("tensorflow");
    }

    private static PythonObject getKerasModule() throws PythonException {
        PythonObject tf = installAndImportTF();
        PythonObject keras = tf.attr("keras");
        tf.del();
        return keras;
    }

    private static PythonObject loadModel(String s) throws PythonException {
        PythonObject models = getKerasModule().attr("models");
        PythonObject loadModelF = models.attr("load_model");
        PythonObject model = loadModelF.call(s);
        models.del();
        loadModelF.del();
        return model;
    }

    public Model(String path) throws PythonException {
        pyModel = loadModel(path);
    }

    public INDArray[] predict(INDArray... inputs) throws PythonException {
        PythonObject predictF = pyModel.attr("predict");
        PythonObject inputList = new PythonObject(inputs);
        PythonObject pyOut = predictF.call(inputList);
        INDArray[] out;
        if (Python.isinstance(pyOut, Python.listType())) {
            out = new INDArray[Python.len(pyOut).toInt()];
            for (int i = 0; i < out.length; i++) {
                out[i] = pyOut.get(i).toNumpy().getNd4jArray();
            }
        } else {
            out = new INDArray[]{pyOut.toNumpy().getNd4jArray()};
        }
        predictF.del();
        inputList.del();
        pyOut.del();
        return out;
    }

    public int numInputs() {
        PythonObject inputs = pyModel.attr("inputs");
        PythonObject pyNumInputs = Python.len(inputs);
        int ret = pyNumInputs.toInt();
        inputs.del();
        pyNumInputs.del();
        return ret;
    }

    public int numOutputs() {
        PythonObject outputs = pyModel.attr("outputs");
        PythonObject pyNumOutputs = Python.len(outputs);
        int ret = pyNumOutputs.toInt();
        outputs.del();
        pyNumOutputs.del();
        return ret;
    }

    public long[][] inputShapes() {
        long[][] ret = new long[numInputs()][];
        for (int i = 0; i < ret.length; i++) {
            ret[i] = inputShapeAt(i);
        }
        return ret;
    }

    public long[][] outputShapes() {
        long[][] ret = new long[numOutputs()][];
        for (int i = 0; i < ret.length; i++) {
            ret[i] = outputShapeAt(i);
        }
        return ret;
    }

    public long[] inputShapeAt(int input) {
        PythonObject inputs = pyModel.attr("inputs");
        PythonObject tensor = inputs.get(input);
        PythonObject tensorShape = tensor.attr("shape");
        PythonObject shapeList = Python.list(tensorShape);
        PythonObject pyNdim = Python.len(shapeList);
        int ndim = pyNdim.toInt();
        long[] shape = new long[ndim];
        for (int i = 0; i < shape.length; i++) {
            PythonObject pyDim = shapeList.get(i);
            if (pyDim == null || !Python.isinstance(pyDim, Python.intType())) {
                shape[i] = -1;   // dynamic dimension (e.g. batch size)
            } else {
                shape[i] = pyDim.toLong();
            }
        }
        pyNdim.del();
        shapeList.del();
        tensorShape.del();
        tensor.del();
        inputs.del();
        return shape;
    }

    public long[] outputShapeAt(int output) {
        PythonObject outputs = pyModel.attr("outputs");
        PythonObject tensor = outputs.get(output);
        PythonObject tensorShape = tensor.attr("shape");
        PythonObject shapeList = Python.list(tensorShape);
        PythonObject pyNdim = Python.len(shapeList);
        int ndim = pyNdim.toInt();
        long[] shape = new long[ndim];
        for (int i = 0; i < shape.length; i++) {
            PythonObject pyDim = shapeList.get(i);
            if (pyDim == null || !Python.isinstance(pyDim, Python.intType())) {
                shape[i] = -1;   // dynamic dimension (e.g. batch size)
            } else {
                shape[i] = pyDim.toLong();
            }
        }
        pyNdim.del();
        shapeList.del();
        tensorShape.del();
        tensor.del();
        outputs.del();
        return shape;
    }
}
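
A corresponding sketch for the Model wrapper, assuming a saved tf.keras model with a single fixed-size rank-2 [batch, features] input ("my_model.h5" is a placeholder path, the class name is illustrative):

    import org.datavec.python.keras.Model;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class KerasModelExample {
        public static void main(String[] args) throws Exception {
            Model model = new Model("my_model.h5");        // placeholder path
            long[] inShape = model.inputShapeAt(0);        // dynamic dims (e.g. batch) come back as -1
            INDArray x = Nd4j.rand(1, (int) inShape[1]);   // assumes a fixed feature dimension
            INDArray[] y = model.predict(x);
            System.out.println(y[0]);
        }
    }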

View File

@@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.layers.BaseLayer;
 import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;

@@ -153,11 +154,22 @@ public class TestUtils {
         return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
     }

-    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
-        INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
+    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) {
+        return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng);
+    }
+
+    public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){
+        boolean ncw = format == RNNFormat.NCW;
+        long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize};
+        char order = ncw ? 'f' : 'c';
+        INDArray out = Nd4j.create(DataType.FLOAT, shape, order);
         for( int i=0; i<minibatch; i++ ){
             for( int j=0; j<tsLength; j++ ){
-                out.putScalar(i, rng.nextInt(outSize), j, 1.0);
+                if(ncw){
+                    out.putScalar(i, rng.nextInt(outSize), j, 1.0);
+                } else {
+                    out.putScalar(i, j, rng.nextInt(outSize), 1.0);
+                }
             }
         }
         return out;
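
For reference, a small sketch of what the format-aware overload above produces (illustrative snippet, assuming the test-scope TestUtils shown here plus ND4J on the classpath):

    import java.util.Arrays;
    import java.util.Random;
    import org.deeplearning4j.TestUtils;
    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class OneHotTimeSeriesExample {
        public static void main(String[] args) {
            INDArray ncw = TestUtils.randomOneHotTimeSeries(RNNFormat.NCW, 2, 4, 3, new Random(42));
            INDArray nwc = TestUtils.randomOneHotTimeSeries(RNNFormat.NWC, 2, 4, 3, new Random(42));
            System.out.println(Arrays.toString(ncw.shape()));   // [2, 4, 3] = [minibatch, size, timeSteps]
            System.out.println(Arrays.toString(nwc.shape()));   // [2, 3, 4] = [minibatch, timeSteps, size]
        }
    }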

View File

@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
 import org.deeplearning4j.nn.conf.layers.*;

@@ -560,75 +561,81 @@ public class GradientCheckTests extends BaseDL4JTest {
    public void testEmbeddingSequenceLayer(){
        Nd4j.getRandom().setSeed(12345);

        for(RNNFormat seqOutputFormat : RNNFormat.values()) {
            for (boolean maskArray : new boolean[]{false, true}) {
                for (int inputRank : new int[]{2, 3}) {

                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                            .dataType(DataType.DOUBLE)
                            .seed(12345)
                            .updater(new NoOp())
                            .weightInit(new NormalDistribution(0, 1))
                            .list()
                            .layer(new EmbeddingSequenceLayer.Builder()
                                    .nIn(8)
                                    .nOut(4)
                                    .outputDataFormat(seqOutputFormat)
                                    .build())
                            .layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
                                    .dataFormat(seqOutputFormat)
                                    .lossFunction(LossFunction.MSE).build())
                            .build();

                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
                    net.init();

                    boolean ncw = seqOutputFormat == RNNFormat.NCW;

                    INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
                    INDArray label = Nd4j.rand(DataType.FLOAT, ncw ? new int[]{3, 3, 6} : new int[]{3, 6, 3});

                    if (inputRank == 3) {
                        //Reshape from [3,6] to [3,1,6]
                        in = in.reshape('c', 3, 1, 6);
                    }

                    INDArray fMask = null;
                    if (maskArray) {
                        fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
                                {1, 1, 0, 0, 0, 0},
                                {1, 0, 0, 0, 0, 0}});
                    }

                    String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
                            .labels(label).inputMask(fMask));
                    assertTrue(msg, gradOK);
                    TestUtils.testModelSerialization(net);

                    //Also: if mask is present, double check that the masked steps don't impact score
                    if (maskArray) {
                        DataSet ds = new DataSet(in, label, fMask, null);
                        double score = net.score(ds);
                        if (inputRank == 2) {
                            in.putScalar(1, 2, 0);
                            in.putScalar(2, 1, 0);
                            in.putScalar(2, 2, 0);
                        } else {
                            in.putScalar(1, 0, 2, 0);
                            in.putScalar(2, 0, 1, 0);
                            in.putScalar(2, 0, 2, 0);
                        }
                        double score2 = net.score(ds);
                        assertEquals(score, score2, 1e-6);
                        if (inputRank == 2) {
                            in.putScalar(1, 2, 1);
                            in.putScalar(2, 1, 1);
                            in.putScalar(2, 2, 1);
                        } else {
                            in.putScalar(1, 0, 2, 1);
                            in.putScalar(2, 0, 1, 1);
                            in.putScalar(2, 0, 2, 1);
                        }
                        double score3 = net.score(ds);
                        assertEquals(score, score3, 1e-6);
                    }
                }
            }
        }
    }

View File

@@ -21,9 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
 import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
-import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.*;
 import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.distribution.UniformDistribution;

@@ -341,104 +339,112 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
    @Test
    public void testCnnDepthMerge() {

        for(CNN2DFormat format : CNN2DFormat.values()) {

            String msg = "testCnnDepthMerge - " + format;

            Nd4j.getRandom().setSeed(12345);
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .dataType(DataType.DOUBLE)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .dist(new NormalDistribution(0, 0.1))
                    .updater(new NoOp()).graphBuilder().addInputs("input")
                    .addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
                            .nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
                    .addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
                            .nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
                    .addVertex("merge", new MergeVertex(), "l1", "l2")
                    .addLayer("outputLayer",
                            new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
                                    .build(),
                            "merge")
                    .setOutputs("outputLayer")
                    .setInputTypes(InputType.convolutional(6, 6, 2, format))
                    .build();

            ComputationGraph graph = new ComputationGraph(conf);
            graph.init();

            Random r = new Random(12345);
            INDArray input = Nd4j.rand(DataType.DOUBLE, format == CNN2DFormat.NCHW ? new long[]{5, 2, 6, 6} : new long[]{5, 6, 6, 2});
            INDArray labels = Nd4j.zeros(5, 3);
            for (int i = 0; i < 5; i++)
                labels.putScalar(new int[]{i, r.nextInt(3)}, 1.0);

            if (PRINT_RESULTS) {
                System.out.println(msg);
//                for (int j = 0; j < graph.getNumLayers(); j++)
//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
            }

            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
                    .labels(new INDArray[]{labels}));

            assertTrue(msg, gradOK);
            TestUtils.testModelSerialization(graph);
        }
    }

    @Test
    public void testRNNWithMerging() {

        for(RNNFormat format : RNNFormat.values()) {

            String msg = "testLSTMWithMerging - " + format;

            Nd4j.getRandom().setSeed(12345);
            ComputationGraphConfiguration conf =
                    new NeuralNetConfiguration.Builder().seed(12345)
                            .dataType(DataType.DOUBLE)
                            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                            .dist(new UniformDistribution(0.2, 0.6))
                            .updater(new NoOp()).graphBuilder().addInputs("input")
                            .setOutputs("out")
                            .addLayer("lstm1",
                                    new SimpleRnn.Builder().nIn(3).nOut(3)
                                            .activation(Activation.TANH).build(),
                                    "input")
                            .addLayer("lstm2",
                                    new SimpleRnn.Builder().nIn(3).nOut(3)
                                            .activation(Activation.TANH).build(),
                                    "lstm1")
                            .addLayer("dense1",
                                    new DenseLayer.Builder().nIn(3).nOut(3)
                                            .activation(Activation.SIGMOID).build(),
                                    "lstm1")
                            .addLayer("lstm3",
                                    new SimpleRnn.Builder().nIn(3).nOut(3)
                                            .activation(Activation.TANH).build(),
                                    "dense1")
                            .addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
                            .addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
                                            .activation(Activation.SOFTMAX)
                                            .lossFunction(LossFunctions.LossFunction.MCXENT).build(),
                                    "merge")
                            .setInputTypes(InputType.recurrent(4, format))
                            .build();

            ComputationGraph graph = new ComputationGraph(conf);
            graph.init();

            Random r = new Random(12345);
            INDArray input = Nd4j.rand(DataType.DOUBLE, format == RNNFormat.NCW ? new long[]{2, 3, 4} : new long[]{2, 4, 3});
            INDArray labels = TestUtils.randomOneHotTimeSeries(format, 2, 3, 4, new Random(12345));

            if (PRINT_RESULTS) {
                System.out.println(msg);
//                for (int j = 0; j < graph.getNumLayers(); j++)
//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
            }

            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
                    .labels(new INDArray[]{labels}));

            assertTrue(msg, gradOK);
            TestUtils.testModelSerialization(graph);
        }
    }

    @Test

View File

@@ -17,7 +17,9 @@
 package org.deeplearning4j.nn.dtypes;

 import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
+import org.deeplearning4j.nn.conf.preprocessor.*;
 import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
+import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
 import org.nd4j.shade.guava.collect.ImmutableSet;
 import org.nd4j.shade.guava.reflect.ClassPath;
 import lombok.extern.slf4j.Slf4j;

@@ -51,16 +53,11 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
 import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
 import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
 import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
-import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
-import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
-import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
-import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.util.IdentityLayer;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
-import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.weights.WeightInitDistribution;

@@ -97,7 +94,8 @@ public class DTypeTests extends BaseDL4JTest {
             Pooling2D.class,      //Alias for SubsamplingLayer
             Convolution2D.class,  //Alias for ConvolutionLayer
             Pooling1D.class,      //Alias for Subsampling1D
-            Convolution1D.class   //Alias for Convolution1DLayer
+            Convolution1D.class,  //Alias for Convolution1DLayer
+            TensorFlowCnnToFeedForwardPreProcessor.class   //Deprecated
     ));

     @Override

@@ -1078,7 +1076,7 @@ public class DTypeTests extends BaseDL4JTest {
                     .addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in")
                     .addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l")
                     .addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc")
-                    .addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16})), "preproc2")
+                    .addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16}, false)), "preproc2")
                     .addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3")
                     .setInputTypes(InputType.feedForward(5))
                     .setOutputs("out");

@@ -1150,7 +1148,7 @@ public class DTypeTests extends BaseDL4JTest {
                 case 7:
                     b.addInputs("in")
                             .addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
-                            .addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5)), "1")
+                            .addVertex("2", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 5)), "1")
                             .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
                             .setOutputs("out")
                             .setInputTypes(InputType.convolutional(28, 28, 1));

View File

@@ -60,7 +60,7 @@ public class TestGraphNodes extends BaseDL4JTest {
     @Test
     public void testMergeNode() {
         Nd4j.getRandom().setSeed(12345);
-        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
+        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);

         INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4);
         INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100);

@@ -82,7 +82,7 @@ public class TestGraphNodes extends BaseDL4JTest {
     public void testMergeNodeRNN() {

         Nd4j.getRandom().setSeed(12345);
-        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
+        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);

         INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5);
         INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100);

@@ -103,7 +103,7 @@ public class TestGraphNodes extends BaseDL4JTest {
     @Test
     public void testCnnDepthMerge() {
         Nd4j.getRandom().setSeed(12345);
-        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
+        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);

         INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2);
         INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10);

View File

@@ -15,14 +15,13 @@
  ******************************************************************************/

 package org.deeplearning4j.nn.layers.convolution;

-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
+import lombok.*;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
+import org.deeplearning4j.nn.api.MaskState;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;

@@ -30,8 +29,12 @@ import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
 import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
+import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.ArrayType;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;

@@ -516,6 +519,40 @@ public class ConvDataFormatTests extends BaseDL4JTest {
         }
     }

+    @Test
+    public void testGlobalPooling() {
+        try {
+            for (boolean helpers : new boolean[]{false, true}) {
+                for (PoolingType pt : PoolingType.values()) {
+                    Nd4j.getRandom().setSeed(12345);
+                    Nd4j.getEnvironment().allowHelpers(helpers);
+                    String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
+                    System.out.println(" --- " + msg + " ---");
+
+                    INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
+                    INDArray labels = TestUtils.randomOneHot(2, 10);
+
+                    TestCase tc = TestCase.builder()
+                            .msg(msg)
+                            .net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
+                            .net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
+                            .net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
+                            .net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
+                            .inNCHW(inNCHW)
+                            .labelsNCHW(labels)
+                            .labelsNHWC(labels)
+                            .testLayerIdx(1)
+                            .build();
+
+                    testHelper(tc);
+                }
+            }
+        } finally {
+            Nd4j.getEnvironment().allowHelpers(true);
+        }
+    }
+
     private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
         if (setOnLayerAlso) {
             return getNetWithLayer(new ConvolutionLayer.Builder()

@@ -735,11 +772,28 @@ public class ConvDataFormatTests extends BaseDL4JTest {
                 .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
                 .setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));

+        if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
+            //Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
+            //DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
+            builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
+        }
+
         MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
         net.init();
         return net;
     }

+    private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
+        if (setOnLayerAlso) {
+            return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
+                    .poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2, 3} : new int[]{1, 2})
+                    .build(), format, ConvolutionMode.Same, null);
+        } else {
+            return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
+                    .build(), format, ConvolutionMode.Same, null);
+        }
+    }
+
     private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){
         NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
                 .seed(12345)

@@ -799,8 +853,13 @@ public class ConvDataFormatTests extends BaseDL4JTest {
         INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1);

         assertEquals(tc.msg, l0_1, l0_2);
-        assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
-        assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
+        if(l0_1.rank() == 4) {
+            assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
+            assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
+        } else {
+            assertEquals(tc.msg, l0_1, l0_3);
+            assertEquals(tc.msg, l0_1, l0_4);
+        }

         INDArray out1 = tc.net1.output(inNCHW);

@@ -880,4 +939,36 @@ public class ConvDataFormatTests extends BaseDL4JTest {
         }
         return differs;
     }
+
+    //Converts NHWC to NCHW activations
+    @EqualsAndHashCode
+    private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
+
+        @Override
+        public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0, 3, 1, 2));
+        }
+
+        @Override
+        public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0, 2, 3, 1));
+        }
+
+        @Override
+        public InputPreProcessor clone() {
+            return this;
+        }
+
+        @Override
+        public InputType getOutputType(InputType inputType) {
+            InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
+            return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
+        }
+
+        @Override
+        public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
+            return null;
+        }
+    }
 }
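
The NHWCToNCHWPreprocessor above exists because row-major flattening of NHWC and NCHW activations yields different element orders, and DL4J's flattening deliberately matches Keras (hence TF) for import compatibility. A minimal ND4J sketch of the difference (illustrative class name, not part of the diff):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class FlattenOrderExample {
        public static void main(String[] args) {
            INDArray nchw = Nd4j.arange(24).reshape(2, 3, 2, 2);   // [n, c, h, w]
            INDArray nhwc = nchw.permute(0, 2, 3, 1);              // same values, viewed as [n, h, w, c]
            // Row-major ('c' order) flattening walks the last axis fastest, so:
            System.out.println(nchw.dup('c').reshape(2, 12));      // channel-major: all of channel 0, then channel 1, ...
            System.out.println(nhwc.dup('c').reshape(2, 12));      // pixel-major: channels interleaved per pixel
        }
    }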

View File

@@ -44,6 +44,7 @@ import org.nd4j.linalg.primitives.Pair;

 import java.util.Arrays;
 import java.util.List;
+import java.util.Random;

 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;

@@ -217,7 +218,7 @@ public class TestRnnLayers extends BaseDL4JTest {

             NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
                     .list()
-                    .layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
+                    .layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build());

             switch (i){
                 case 0:

@@ -235,10 +236,7 @@ public class TestRnnLayers extends BaseDL4JTest {
             net.init();

             INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
-            INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
-            if (rnnDataFormat == RNNFormat.NWC){
-                l = l.permute(0, 2, 1);
-            }
+            INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new Random(12345));
             try{
                 net.fit(in,l);
             } catch (Throwable t){

View File

@@ -61,15 +61,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
         int tsLength = 7;
         INDArray in;
         if (rnnDataFormat == RNNFormat.NCW){
-            in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
+            in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
         }
         else{
-            in = Nd4j.rand(DataType.FLOAT, new int[]{m, tsLength, nIn});
+            in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
         }
-        // in.get(all(), all(), interval(1,tsLength)).assign(0);

         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                 .updater(new NoOp())
                 .weightInit(WeightInit.XAVIER)

View File

@@ -7,10 +7,14 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.WorkspaceMode;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.DenseLayer;
-import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
+import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.LSTM;
 import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
+import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
 import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
+import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.junit.Test;
 import org.junit.runner.RunWith;

@@ -106,4 +110,73 @@ public class TestTimeDistributed extends BaseDL4JTest {
             }
         }
     }
+
+    @Test
+    public void testTimeDistributedDense(){
+        for( int rnnType = 0; rnnType < 3; rnnType++ ) {
+            for( int ffType = 0; ffType < 3; ffType++ ) {
+
+                Layer l0, l2;
+                switch (rnnType) {
+                    case 0:
+                        l0 = new LSTM.Builder().nOut(5).build();
+                        l2 = new LSTM.Builder().nOut(5).build();
+                        break;
+                    case 1:
+                        l0 = new SimpleRnn.Builder().nOut(5).build();
+                        l2 = new SimpleRnn.Builder().nOut(5).build();
+                        break;
+                    case 2:
+                        l0 = new Bidirectional(new LSTM.Builder().nOut(5).build());
+                        l2 = new Bidirectional(new LSTM.Builder().nOut(5).build());
+                        break;
+                    default:
+                        throw new RuntimeException("Not implemented: " + rnnType);
+                }
+
+                Layer l1;
+                switch (ffType){
+                    case 0:
+                        l1 = new DenseLayer.Builder().nOut(5).build();
+                        break;
+                    case 1:
+                        l1 = new VariationalAutoencoder.Builder().nOut(5).encoderLayerSizes(5).decoderLayerSizes(5).build();
+                        break;
+                    case 2:
+                        l1 = new AutoEncoder.Builder().nOut(5).build();
+                        break;
+                    default:
+                        throw new RuntimeException("Not implemented: " + ffType);
+                }
+
+                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .activation(Activation.TANH)
+                        .list()
+                        .layer(l0)
+                        .layer(l1)
+                        .layer(l2)
+                        .setInputType(InputType.recurrent(5, 9, rnnDataFormat))
+                        .build();
+
+                BaseRecurrentLayer l0a;
+                BaseRecurrentLayer l2a;
+                if (rnnType < 2) {
+                    l0a = (BaseRecurrentLayer) l0;
+                    l2a = (BaseRecurrentLayer) l2;
+                } else {
+                    l0a = (BaseRecurrentLayer) ((Bidirectional) l0).getFwd();
+                    l2a = (BaseRecurrentLayer) ((Bidirectional) l2).getFwd();
+                }
+                assertEquals(rnnDataFormat, l0a.getRnnDataFormat());
+                assertEquals(rnnDataFormat, l2a.getRnnDataFormat());
+
+                MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                net.init();
+                INDArray in = Nd4j.rand(DataType.FLOAT, rnnDataFormat == RNNFormat.NCW ? new long[]{2, 5, 9} : new long[]{2, 9, 5});
+                net.output(in);
+            }
+        }
+    }
 }

View File

@@ -15,21 +15,24 @@
  ******************************************************************************/

 package org.deeplearning4j.convolution;

-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
+import lombok.*;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.CuDNNTestUtils;
 import org.deeplearning4j.TestUtils;
+import org.deeplearning4j.nn.api.MaskState;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
+import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.ArrayType;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;

@@ -816,6 +819,12 @@ public class ConvDataFormatTests extends BaseDL4JTest {
                 .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
                 .setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));

+        if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
+            //Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
+            //DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
+            builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
+        }
+
         MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
         net.init();
         return net;

@@ -964,4 +973,35 @@ public class ConvDataFormatTests extends BaseDL4JTest {
         }
         return differs;
     }
+
+    //Converts NHWC to NCHW activations
+    @EqualsAndHashCode
+    private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
+
+        @Override
+        public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0, 3, 1, 2));
+        }
+
+        @Override
+        public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0, 2, 3, 1));
+        }
+
+        @Override
+        public InputPreProcessor clone() {
+            return this;
+        }
+
+        @Override
+        public InputType getOutputType(InputType inputType) {
+            InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
+            return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
+        }
+
+        @Override
+        public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
+            return null;
+        }
+    }
 }

View File

@@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest {
         ComputationGraph model = KerasModelImport.importKerasModelAndWeights(fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
         model = model.convertDataType(DataType.DOUBLE);

-        INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize});
+        INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC

         CuDNNTestUtils.assertHelpersPresent(model.getLayers());
         Map<String,INDArray> withCudnn = model.feedForward(in, false);

View File

@@ -113,6 +113,12 @@
             <scope>test</scope>
         </dependency>

+        <dependency>
+            <groupId>org.datavec</groupId>
+            <artifactId>datavec-python</artifactId>
+            <version>${datavec.version}</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>

     <profiles>

View File

@@ -19,6 +19,9 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.ArrayUtils;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.Convolution3D;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;

@@ -121,27 +124,29 @@ public class KerasInput extends KerasLayer {
         InputType myInputType;
         switch (this.inputShape.length) {
             case 1:
-                myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
+                myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
                 break;
             case 2:
                 if(this.dimOrder != null) {
+                    System.out.println("Dim order: " + this.dimOrder);
+                    System.out.println("Input shape: " + ArrayUtils.toString(this.inputShape));
                     switch (this.dimOrder) {
                         case TENSORFLOW:    //NWC == channels_last
-                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                             break;
                         case THEANO:        //NCW == channels_first
-                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
                             break;
                         case NONE:
                             //Assume RNN in [mb, seqLen, size] format
-                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                             break;
                         default:
                             throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
                     }
                 } else {
                     //Assume RNN in [mb, seqLen, size] format
-                    myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                    myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                 }
                 break;

@@ -150,17 +155,17 @@ public class KerasInput extends KerasLayer {
                     case TENSORFLOW:
                         /* TensorFlow convolutional input: # rows, # cols, # channels */
                         myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
-                                this.inputShape[2]);
+                                this.inputShape[2], CNN2DFormat.NHWC);
                         break;
                     case THEANO:
                         /* Theano convolutional input: # channels, # rows, # cols */
                         myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
-                                this.inputShape[0]);
+                                this.inputShape[0], CNN2DFormat.NCHW);
                         break;
                     default:
                         this.dimOrder = DimOrder.THEANO;
                         myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
-                                this.inputShape[0]);
+                                this.inputShape[0], CNN2DFormat.NCHW);
                         log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
                                 "versions may come without specified backend, in which case we assume the model was " +
                                 "built with theano.");

View File

@@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;

@@ -65,6 +66,9 @@ public class TFOpLayer extends Layer {
         long[] shape = inputType.getShape(true);
         TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
         long[] outputShape = tempLayer.getOutputShape(shape);
+        if (outputShape.length == 3){
+            return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC);
+        }
         return InputType.inferInputType(Nd4j.create(outputShape));
     }

View File

@@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
    }

    private INDArray runGraph(INDArray input){
-        if (input.rank() == 3){
-            // TODO make this a preprocessor
-            input = input.permute(0, 2, 1);
-        }
        Map<String, INDArray> inputMap = new HashMap<>();
        inputMap.put(inputNames.get(0), input);
        INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
-        if (out.rank() == 3){
-            out = out.permute(0, 2, 1); // TODO post-processing?
-        }
        return out;
    }

View File

@@ -95,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution {
        IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
                enforceTrainingConfig, conf, kerasMajorVersion);

        Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
                .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
                .activation(getIActivationFromConfig(layerConfig, conf))
@@ -104,7 +103,7 @@ public class KerasConvolution1D extends KerasConvolution {
                .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
                .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
                .hasBias(hasBias)
-                .stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
+                .stride(getStrideFromConfig(layerConfig, 1, conf)[0]).rnnDataFormat(dimOrder == DimOrder.TENSORFLOW? RNNFormat.NWC: RNNFormat.NCW);
        int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
        if (hasBias)
            builder.biasInit(0.0);

View File

@@ -20,6 +20,7 @@ import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
@@ -27,6 +28,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
 import org.deeplearning4j.nn.weights.IWeightInit;
+import oshi.jna.platform.windows.PowrProf;

 import java.util.Map;
@@ -93,6 +95,7 @@ public class KerasConvolution2D extends KerasConvolution {
        LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
                layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);

+        System.out.println("----" + dimOrder);
        ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
                .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
                .activation(getIActivationFromConfig(layerConfig, conf))
@@ -101,7 +104,8 @@ public class KerasConvolution2D extends KerasConvolution {
                .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
                .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
                .hasBias(hasBias)
-                .stride(getStrideFromConfig(layerConfig, 2, conf));
+                .stride(getStrideFromConfig(layerConfig, 2, conf))
+                .dataFormat((dimOrder==DimOrder.TENSORFLOW)? CNN2DFormat.NHWC:CNN2DFormat.NCHW);
        int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion);
        if (hasBias)
            builder.biasInit(0.0);
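The 2D convolution importer now encodes the Keras dim ordering directly on the layer rather than relying on a permute preprocessor. A hedged sketch of the mapping the ternary above applies (the helper class and method name are hypothetical, not part of the project's API):

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer.DimOrder;

class DimOrderMappingSketch {
    // TENSORFLOW dim ordering (channels_last) -> NHWC; THEANO/unknown -> NCHW
    static CNN2DFormat cnnFormat(DimOrder dimOrder) {
        return dimOrder == DimOrder.TENSORFLOW ? CNN2DFormat.NHWC : CNN2DFormat.NCHW;
    }
}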

View File

@@ -360,8 +360,19 @@ public class KerasConvolutionUtils {
            }
        } else if (dimension == 1) {
-            int paddingInt = (int) innerConfig.get(layerField);
-            padding = new int[]{paddingInt, paddingInt};
+            Object paddingObj = innerConfig.get(layerField);
+            if (paddingObj instanceof List){
+                List<Integer> paddingList = (List)paddingObj;
+                padding = new int[]{
+                        paddingList.get(0),
+                        paddingList.get(1)
+                };
+            }
+            else{
+                int paddingInt = (int) innerConfig.get(layerField);
+                padding = new int[]{paddingInt, paddingInt};
+            }
        } else {
            throw new UnsupportedKerasConfigurationException(
                    "Keras padding layer not supported");

View File

@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
@@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
-import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;

 import java.util.Map;
@@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer {
                switch (this.getDimOrder()) {
                    case NONE:
                    case THEANO:
-                        preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels());
+                        preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW);
                        break;
                    case TENSORFLOW:
-                        preprocessor = new TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(),
-                                it.getChannels());
+                        preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC);
                        break;
                    default:
                        throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder());
@@ -111,7 +110,7 @@ public class KerasFlatten extends KerasLayer {
            // to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
            InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
            val inputShape = new long[]{it.getSize()};
-            preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
+            preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null);
        }
        return preprocessor;
    }

View File

@@ -17,6 +17,7 @@
 package org.deeplearning4j.nn.modelimport.keras.layers.core;

 import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer {
        super(layerConfig, enforceTrainingConfig);
        this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf))
+                .dataFormat(RNNFormat.NWC)
                .name(this.layerName).build();
    }

View File

@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
 import lombok.val;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer {
                } else {
                    targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
                }
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW);
            } else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
-                if (inputShape[0] != targetShape[0])
-                    targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
-                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
+                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC);
            }
        } else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer {
            } else {
                targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
            }
-            preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
+            preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
        } else {
            if (inputShape[0] != targetShape[0])
                targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
-            preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
+            preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
        }
    } else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
        InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
        val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
-        preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
+        preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
    } else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
        InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
        val inputShape = new long[]{it.getSize()};
        if (targetShape.length == 3) {
            targetShape = targetShapeForDimOrder(inputShape, targetShape);
        }
-        preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
+        preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
    }
    return preprocessor;
}

View File

@@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer {
                .biasInit(0.0)
                .l1(this.weightL1Regularization)
                .l2(this.weightL2Regularization)
+                .outputDataFormat(RNNFormat.NWC)
                .hasBias(false);
        if (embeddingConstraint != null)
            builder.constrainWeights(embeddingConstraint);

View File

@@ -186,7 +186,7 @@ public class KerasLSTM extends KerasLayer {
                .weightInitRecurrent(recurrentInit)
                .biasInit(0.0) // TODO: this is incorrect
                .l1(this.weightL1Regularization)
-                .l2(this.weightL2Regularization);
+                .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
        Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
        if(nIn != null)
            builder.setNIn(nIn);
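Imported LSTM (and, below, SimpleRNN) layers are now declared NWC, so the time axis stays where Keras put it. A minimal sketch of an equivalent hand-built configuration, assuming the builder API used in this PR; the layer sizes and input shapes are illustrative:

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class NwcLstmSketch {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new LSTM.Builder().nIn(4).nOut(8).dataFormat(RNNFormat.NWC).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        INDArray in = Nd4j.rand(new int[]{2, 10, 4});   // [minibatch, seqLength, size] - NWC
        INDArray out = net.output(in);
        System.out.println(java.util.Arrays.toString(out.shape()));   // expected [2, 10, 8], still NWC
    }
}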

View File

@@ -158,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer {
                .weightInitRecurrent(recurrentInit)
                .biasInit(0.0)
                .l1(this.weightL1Regularization)
-                .l2(this.weightL2Regularization);
+                .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
        Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
        if(nIn != null)
            builder.setNIn(nIn);

View File

@@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer {
                break;
            case "SimpleRNN":
                kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
-                SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
+                Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
                this.layer = new Bidirectional(mode, rnnLayer);
                layer.setLayerName(layerName);
                break;

View File

@@ -21,6 +21,9 @@ import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.DataFormat;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
 import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
    private final long[] inputShape;
    private final long[] targetShape;
    private boolean hasMiniBatchDimension;
+    private DataFormat format;

+    /**
+     * @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
+     */
+    @Deprecated
+    public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
+        this(inputShape, targetShape, false);
+    }

    /**
     * @param inputShape            Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
     * @param targetShape           Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
     * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
     */
+    public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) {
+        this(inputShape, targetShape, hasMiniBatchDimension, null);
+    }

+    /**
+     * @param inputShape            Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
+     * @param targetShape           Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
+     * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
+     * @param dataFormat            May be null. If non-null: used to determine the format (NCW/NWC, NCHW/NHWC) of the output InputType
+     */
    public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
-                               @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
+                               @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension,
+                               @JsonProperty("dataFormat") DataFormat dataFormat) {
        this.inputShape = inputShape;
        this.targetShape = targetShape;
        this.hasMiniBatchDimension = hasMiniBatchDimension;
+        this.format = dataFormat;
    }

    private long[] getShape(long[] originalShape, long minibatch) {
@@ -140,13 +148,26 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
                ret = InputType.feedForward(shape[1]);
                break;
            case 3:
-                ret = InputType.recurrent(shape[2], shape[1]);
+                RNNFormat format = RNNFormat.NCW;
+                if(this.format != null && this.format instanceof RNNFormat)
+                    format = (RNNFormat)this.format;
+                ret = InputType.recurrent(shape[2], shape[1], format);
                break;
            case 4:
                if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
                    ret = InputType.convolutional(shape[1], shape[2], shape[3]);
                } else {
-                    ret = InputType.convolutional(shape[2], shape[3], shape[1]);
+                    CNN2DFormat cnnFormat = CNN2DFormat.NCHW;
+                    if (this.format != null && this.format instanceof CNN2DFormat)
+                        cnnFormat = (CNN2DFormat) this.format;
+                    if (cnnFormat == CNN2DFormat.NCHW) {
+                        ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
+                    } else {
+                        ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
+                    }
                }
                break;
            default:
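Usage sketch for the extended constructor: the trailing DataFormat only affects the InputType reported to the next layer, not the reshape itself. Assuming a feed-forward input of size 60 reshaped to 10x6 and declared NWC (values are illustrative):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;

public class ReshapePreprocessorSketch {
    public static void main(String[] args) {
        ReshapePreprocessor pp =
                new ReshapePreprocessor(new long[]{60}, new long[]{10, 6}, false, RNNFormat.NWC);
        // Expected: a recurrent InputType with size 6, sequence length 10, NWC format
        System.out.println(pp.getOutputType(InputType.feedForward(60)));
    }
}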

View File

@@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator;
 import org.nd4j.shade.jackson.annotation.JsonProperty;

 /**
- * Specialized CnnToFeedForwardInputPreProcessor for use with
- * Convolutional layers imported from Keras using the TensorFlow
- * backend.
- *
- * @author dave@skymind.io
+ * @deprecated Exists only for backward compatibility of older pretrained models. Should not be used.
+ * Use {@link CnnToFeedForwardPreProcessor} for all new models instead.
  */
-@Slf4j
+@Slf4j @Deprecated
 public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor {

-    @JsonCreator
+    @JsonCreator @Deprecated
    public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight,
                                                  @JsonProperty("inputWidth") long inputWidth,
                                                  @JsonProperty("numChannels") long numChannels) {
        super(inputHeight, inputWidth, numChannels);
    }

+    @Deprecated
    public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
        super(inputHeight, inputWidth);
    }

+    @Deprecated
    public TensorFlowCnnToFeedForwardPreProcessor() {
        super();
    }
@@ -81,4 +80,4 @@ public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreP
    public TensorFlowCnnToFeedForwardPreProcessor clone() {
        return (TensorFlowCnnToFeedForwardPreProcessor) super.clone();
    }
 }

View File

@@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
 import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
 import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
 import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
+import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
 import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
 import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
 import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
@@ -319,6 +320,8 @@ public class KerasLayerUtils {
            layer = new KerasELU(layerConfig, enforceTrainingConfig);
        } else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
            layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
+        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){
+            layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
        } else if (conf instanceof Keras2LayerConfiguration){
            Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
            if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){

View File

@@ -1,50 +0,0 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.Arrays;
public class TFKerasTests extends BaseDL4JTest{
@Test
public void testModelWithTFOp1() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
}
@Test
public void testModelWithTFOp2() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
// dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed
long[] expectedShape = new long[]{12 * 2, 5};
Assert.assertArrayEquals(expectedShape, out.shape());
}
}

View File

@@ -0,0 +1,147 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.apache.commons.io.FileUtils;
import org.datavec.python.keras.Model;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.common.tests.ResourceUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.List;
@RunWith(Parameterized.class)
public class TestTFKerasModelImport extends BaseDL4JTest{
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
private String modelFile;
@Override
public long getTimeoutMilliseconds(){
return 300000;
} // installing TF will take a while
@Parameterized.Parameters(name = "file={0}")
public static Object[] params() throws Exception {
List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false);
return paths.toArray(new String[0]);
}
public TestTFKerasModelImport(String modelFile){
this.modelFile = modelFile;
}
@Test
public void testModelImport() throws Exception{
testModelImportWithData(modelFile);
}
private void testModelImportWithData(String path) throws Exception{
System.out.println(path);
// TODO multi input/output
INDArray inputArray;
INDArray expectedOutputArray;
File f = Resources.asFile(path); //May be in a JAR that HDF5 can't read from
File modelFile = new File(testDir.getRoot(), f.getName());
FileUtils.copyFile(f, modelFile);
synchronized (Hdf5Archive.LOCK_OBJECT){
Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath());
List<String> rootGroups = hdf5Archive.getGroups();
if (rootGroups.contains("data")){
String inputName = hdf5Archive.readAttributeAsString("input_names", "data");
String outputName = hdf5Archive.readAttributeAsString("output_names", "data");
inputArray = hdf5Archive.readDataSet(inputName, "data");
expectedOutputArray = hdf5Archive.readDataSet(outputName, "data");
}
else{
hdf5Archive.close();
return;
}
hdf5Archive.close();
}
INDArray outputArray;
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
outputArray = dl4jModel.outputSingle(inputArray);
expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT);
outputArray = outputArray.castTo(DataType.FLOAT);
if (path.contains("misc_")){
//shape relaxation
expectedOutputArray = expectedOutputArray.reshape( -1);
outputArray = outputArray.reshape(-1);
}
System.out.println(outputArray.toString());
System.out.println(expectedOutputArray.toString());
Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape());
Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3));
}
private void testModelImportWithKeras(String path) throws Exception{
Model kerasModel = new Model(path);
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays());
Assert.assertEquals(kerasModel.numOutputs(), dl4jModel.getNumOutputArrays());
INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()];
INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()];
for (int i = 0; i < kerasInputArrays.length; i ++) {
long[] shape = kerasModel.inputShapeAt(i);
for (int j = 0; j < shape.length; j++) {
if (shape[j] < 0) {
shape[j] = 1;
}
}
kerasInputArrays[i] = Nd4j.rand(shape);
}
INDArray[] kerasOut = kerasModel.predict(kerasInputArrays);
INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays);
Assert.assertEquals(kerasOut.length, dl4jOut.length);
for (int i = 0; i < kerasOut.length; i++){
INDArray kerasOutArr = kerasOut[i];
kerasOutArr = kerasOutArr.reshape(1, -1);// bit of relaxation on shape
kerasOutArr= kerasOutArr.castTo(DataType.DOUBLE);
Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST);
INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1);
System.out.println(kerasOutArr.shapeInfoToString());
System.out.println(dl4jOutArr.shapeInfoToString());
Assert.assertEquals(kerasOutArr, dl4jOutArr);
}
}
}

View File

@@ -22,7 +22,6 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
-import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
 import org.junit.Test;

 import static org.junit.Assert.assertEquals;
@@ -34,8 +33,7 @@ public class JsonTest extends BaseDL4JTest {
        InputPreProcessor[] pp = new InputPreProcessor[] {
                new KerasFlattenRnnPreprocessor(10, 5),
                new PermutePreprocessor(new int[]{0,1,2}),
-                new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}),
-                new TensorFlowCnnToFeedForwardPreProcessor()
+                new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}, true, null)
        };
        for(InputPreProcessor p : pp ){

View File

@@ -29,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceTo
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.junit.Ignore;
 import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.resources.Resources;
@@ -250,7 +251,7 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
                .enforceTrainingConfig(false).buildSequential().getMultiLayerConfiguration();
        MultiLayerNetwork model = new MultiLayerNetwork(config);
        model.init();
-        INDArray input = Nd4j.create(50, 500, 1500);
+        INDArray input = Nd4j.create(DataType.FLOAT, 50, 1500, 500); //NWC format - [Minibatch, seqLength, channels]
        INDArray out = model.output(input);
        assertTrue(Arrays.equals(out.shape(), new long[]{50, 64}));
    }

View File

@@ -87,15 +87,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    @Rule
    public final TemporaryFolder testDir = new TemporaryFolder();

-    public static final BiFunction<String,INDArray,INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() {
-        @Override
-        public INDArray apply(String s, INDArray array) {
-            if(array.rank() == 3)
-                return array.permute(0, 2, 1); //NWC to NCW
-            return array;
-        }
-    };

    @Override
    public long getTimeoutMilliseconds() {
        return 180000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources
@@ -169,28 +160,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    public void importImdbLstmTfKeras1() throws Exception {
        String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
    }

    @Test
    public void importImdbLstmThKeras1() throws Exception {
        String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
    }

    @Test
    public void importImdbLstmTfKeras2() throws Exception {
        String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
    }

    @Test
    public void importImdbLstmThKeras2() throws Exception {
        String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected);
+        importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, null);
    }

    /**
@@ -262,7 +253,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
                "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
-        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
+        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
    }

    /**
@@ -316,7 +307,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    @Test
    public void importAcganDiscriminator() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_discriminator_1_epochs.h5");
-        INDArray input = Nd4j.create(10, 1, 28, 28);
+        INDArray input = Nd4j.create(10, 28, 28, 1); //NHWC
        INDArray[] output = model.output(input);
    }
@@ -403,7 +394,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        // Make predictions
        int miniBatch = 32;
-        INDArray input = Nd4j.ones(miniBatch, 4, 10);
+        INDArray input = Nd4j.ones(miniBatch, 10, 4); //NWC format - with nIn=4, seqLength = 10
        INDArray[] out = graph.output(input);

        // Fit model
@@ -450,7 +441,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    @Test
    public void importMobileNet() throws Exception {
        ComputationGraph graph = importFunctionalModelH5Test("modelimport/keras/examples/mobilenet/alternative.hdf5");
-        INDArray input = Nd4j.ones(10, 3, 299, 299);
+        INDArray input = Nd4j.ones(10, 299, 299, 3);
        graph.output(input);
    }
@@ -462,7 +453,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        int[] inputShape = new int[]{299, 299, 3};
        ComputationGraph graph = importFunctionalModelH5Test(
                "modelimport/keras/examples/inception/inception_tf_keras_2.h5", inputShape, false);
-        INDArray input = Nd4j.ones(10, 3, 299, 299);
+        INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC
        graph.output(input);
        System.out.println(graph.summary());
    }
@@ -476,7 +467,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    public void importInception() throws Exception {
        ComputationGraph graph = importFunctionalModelH5Test(
                "modelimport/keras/examples/inception/inception_v3_complete.h5");
-        INDArray input = Nd4j.ones(10, 3, 299, 299);
+        INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC
        graph.output(input);
        System.out.println(graph.summary());
    }
@@ -533,14 +524,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
     * - Separate (policy and value) residual architecture
     * - Separate (policy and value) convolutional architecture
     */
-    @Test
+    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importSepConvPolicy() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_policy.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
        model.output(input);
    }

-    @Test
+    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importSepResPolicy() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_policy.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
@@ -548,28 +539,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    }

-    @Test
+    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importSepConvValue() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_value.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
        model.output(input);
    }

-    @Test
+    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importSepResValue() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_value.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
        model.output(input);
    }

-    @Test
+    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importDualRes() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_res.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
        model.output(input);
    }

-    @Test
+    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importDualConv() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_conv.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
@@ -634,16 +625,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
            System.out.println("Starting test: " + name);
            String modelPath = "modelimport/keras/examples/causal_conv1d/" + name;
            String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
-            Function<INDArray,INDArray> f = new Function<INDArray, INDArray>() {
-                @Override
-                public INDArray apply(INDArray i) {
-                    //NWC to NCW
-                    return i.permute(0, 2, 1);
-                }
-            };

            MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
-                    true, true, false, f, nwc2ncwExpected);
+                    true, true, false, null, null);
            Layer l = net.getLayer(0);
            Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig();
            assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
@@ -707,25 +691,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
            System.out.println("Starting test: " + name);
            String modelPath = "modelimport/keras/examples/conv1d/" + name;
            String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
-            Function<INDArray,INDArray> f = name.contains("_cf_") ? null : new Function<INDArray, INDArray>() {
-                @Override
-                public INDArray apply(INDArray i) {
-                    //NWC to NCW
-                    return i.permute(0, 2, 1);
-                }
-            };
-
-            BiFunction<String,INDArray,INDArray> f2 = name.contains("_cf_") ? null : new BiFunction<String, INDArray, INDArray>() {
-                @Override
-                public INDArray apply(String s, INDArray array) {
-//                    if("conv".equals(s)){
-                        return array.permute(0, 2, 1);
-//                    }
-                }
-            };

            importEndModelTest(modelPath, inputsOutputPath, true, true,
-                    true, true, false, f, f2);
+                    true, true, false, null, null); //f, f2);
        }
    }
@@ -882,8 +850,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        INDArray[] inputs = new INDArray[inputNames.size()];
        for (int i = 0; i < inputNames.size(); i++) {
            inputs[i] = archive.readDataSet(inputNames.get(i), GROUP_ATTR_INPUTS);
-            if (inputs[i].shape().length == 4 && tensorFlowImageDimOrdering)
-                inputs[i] = inputs[i].permute(0, 3, 1, 2);
        }
        return inputs;
    }
@@ -893,8 +859,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        Map<String, INDArray> activations = new HashMap<String, INDArray>();
        for (String layerName : archive.getDataSets(GROUP_ACTIVATIONS)) {
            INDArray activation = archive.readDataSet(layerName, GROUP_ACTIVATIONS);
-            if (activation.shape().length == 4 && tensorFlowImageDimOrdering)
-                activation = activation.permute(0, 3, 1, 2);
            activations.put(layerName, activation);
        }
        return activations;
@@ -907,8 +871,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        INDArray[] outputs = new INDArray[outputNames.size()];
        for (int i = 0; i < outputNames.size(); i++) {
            outputs[i] = archive.readDataSet(outputNames.get(i), GROUP_ATTR_OUTPUTS);
-            if (outputs[i].shape().length == 4 && tensorFlowImageDimOrdering)
-                outputs[i] = outputs[i].permute(0, 3, 1, 2);
        }
        return outputs;
    }
@@ -920,8 +882,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        INDArray[] predictions = new INDArray[outputNames.size()];
        for (int i = 0; i < outputNames.size(); i++) {
            predictions[i] = archive.readDataSet(outputNames.get(i), GROUP_PREDICTIONS);
-            if (predictions[i].shape().length == 4 && tensorFlowImageDimOrdering)
-                predictions[i] = predictions[i].permute(0, 3, 1, 2);
        }
        return predictions;
    }
@@ -941,6 +901,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
            // skip too small absolute inputs
            if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) {
                boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps);
+                if(!eq){
+                    System.out.println("Expected: " + Arrays.toString(expected.shape()) + ", actual: " + Arrays.toString(actual.shape()));
+                    System.out.println("Expected:\n" + expected);
+                    System.out.println("Actual: \n" + actual);
+                }
                assertTrue("Output differs: " + label, eq);
            }
        }
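With layouts preserved on import, the expected activations read from the Keras HDF5 archives can be compared directly; the permutes deleted above are no longer needed. A trivial sketch of the comparison in NHWC (the arrays here are synthetic stand-ins, not PR test data):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LayoutCompareSketch {
    public static void main(String[] args) {
        INDArray kerasOut = Nd4j.rand(new int[]{2, 5, 5, 3});   // reference activation, NHWC
        INDArray dl4jOut = kerasOut.dup();                      // stand-in for the imported model's output
        // Previously the reference had to be permuted (0, 3, 1, 2) to NCHW first
        System.out.println(kerasOut.equalsWithEps(dl4jOut, 1e-4));   // true
    }
}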

View File

@@ -176,10 +176,10 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
        INDArray bias = model.getLayer(0).getParam("b");
        assertEquals(6, bias.length());

-        INDArray input = Nd4j.ones(1, 5, 3, 4);
+        INDArray input = Nd4j.ones(1, 3, 4, 5); //NHWC
        INDArray output = model.output(input);

-        assertArrayEquals(new long[] {1, 6, 1, 2}, output.shape());
+        assertArrayEquals(new long[] {1, 1, 2, 6}, output.shape()); //NHWC
        logSuccess(modelPath);
    }
@@ -224,7 +224,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
        INDArray input = Nd4j.zeros(mb, inputLength);
        INDArray output = model.output(input);
-        assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape());
+        assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC
        logSuccess(modelPath);
    }
@@ -238,9 +238,9 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
        KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);
        MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false);

-        INDArray input = Nd4j.zeros(10, 4, 6, 6);
+        INDArray input = Nd4j.zeros(10, 6, 6, 4);
        INDArray output = model.output(input);
-        assertArrayEquals(new long[]{10, 16, 3, 3}, output.shape());
+        assertArrayEquals(new long[]{10, 3, 3, 16}, output.shape());
        logSuccess(modelPath);
    }
@@ -248,10 +248,11 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
        KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);
        ComputationGraph model = loadComputationalGraph(modelPath, false);

-        INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)};
+//        INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)};
+        INDArray input[] = new INDArray[]{Nd4j.zeros(10, 6, 6, 4), Nd4j.zeros(10, 3, 3, 16)};
        INDArray[] output = model.output(input);
        log.info(Arrays.toString(output[0].shape()));
-        assertArrayEquals(new long[]{10, 32, 3, 3}, output[0].shape());
+        assertArrayEquals(new long[]{10, 3, 3, 32}, output[0].shape());
        logSuccess(modelPath);
    }
@@ -278,7 +279,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
        INDArray inEmbedding = Nd4j.zeros(mb, inputLength);
        INDArray output = model.output(inEmbedding);
-        assertArrayEquals(new long[]{mb, nOut, inputLength}, output.shape());
+        assertArrayEquals(new long[]{mb, inputLength, nOut}, output.shape()); //NWC format
        logSuccess(modelPath);
    }
@@ -304,7 +305,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
        INDArray inEmbedding = Nd4j.zeros(mb, inputLength);
        INDArray output = model.output(inEmbedding);
-        assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape());
+        assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC
        logSuccess(modelPath);
    }

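Aside on the shape changes above: every updated assertion moves test data from channels-first (NCHW) to channels-last (NHWC) layout. A minimal ND4J sketch of that relationship (class name hypothetical, not part of this commit):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class FormatPermuteSketch {
    public static void main(String[] args) {
        INDArray nchw = Nd4j.ones(1, 5, 3, 4);    // [minibatch, channels, height, width]
        INDArray nhwc = nchw.permute(0, 2, 3, 1); // -> [minibatch, height, width, channels]
        System.out.println(java.util.Arrays.toString(nhwc.shape())); // [1, 3, 4, 5]
    }
}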
View File

@ -9,7 +9,7 @@ package org.deeplearning4j.nn.conf;
* *
* @author Alex Black * @author Alex Black
*/ */
public enum CNN2DFormat { public enum CNN2DFormat implements DataFormat {
NCHW, NCHW,
NHWC; NHWC;

View File

@ -0,0 +1,26 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf;
import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
@JsonSerialize(using = DataFormatSerializer.class)
@JsonDeserialize(using = DataFormatDeserializer.class)
public interface DataFormat {
}

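The DataFormat marker interface above exists so that a single configuration field can hold either a CNN2DFormat or an RNNFormat, with JSON handling centralized in the serializer/deserializer added later in this commit. A minimal round-trip sketch, assuming a plain Jackson ObjectMapper that honours the class-level annotations (class name hypothetical, not part of this commit):

import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.shade.jackson.databind.ObjectMapper;

public class DataFormatRoundTripSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        String json = mapper.writeValueAsString(RNNFormat.NWC);     // "NWC"
        DataFormat back = mapper.readValue(json, DataFormat.class); // dispatches on the string value
        System.out.println(back);                                   // NWC
    }
}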
View File

@ -663,7 +663,7 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer; BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer;
val nIn = brl.getNIn(); val nIn = brl.getNIn();
if (nIn > 0) { if (nIn > 0) {
inputType = InputType.recurrent(nIn); inputType = InputType.recurrent(nIn, brl.getRnnDataFormat());
} }
} else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer } else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer
|| firstLayer instanceof OutputLayer) { || firstLayer instanceof OutputLayer) {

View File

@ -23,7 +23,7 @@ package org.deeplearning4j.nn.conf;
* "width" corresponds to sequence length and "channels" corresponds to sequence item size. * "width" corresponds to sequence length and "channels" corresponds to sequence item size.
*/ */
public enum RNNFormat { public enum RNNFormat implements DataFormat {
NCW, NCW,
NWC NWC
} }

View File

@ -18,6 +18,8 @@ package org.deeplearning4j.nn.conf.graph;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.Convolution3D;
@ -38,6 +40,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
*/ */
public class MergeVertex extends GraphVertex { public class MergeVertex extends GraphVertex {
protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format
@Override @Override
public MergeVertex clone() { public MergeVertex clone() {
return new MergeVertex(); return new MergeVertex();
@ -76,7 +80,7 @@ public class MergeVertex extends GraphVertex {
@Override @Override
public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
INDArray paramsView, boolean initializeParams, DataType networkDatatype) { INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype); return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype, mergeAxis);
} }
@Override @Override
@ -126,6 +130,7 @@ public class MergeVertex extends GraphVertex {
//FF or RNN data inputs //FF or RNN data inputs
int size = 0; int size = 0;
InputType.Type type = null; InputType.Type type = null;
RNNFormat format = null;
for (int i = 0; i < vertexInputs.length; i++) { for (int i = 0; i < vertexInputs.length; i++) {
if (vertexInputs[i].getType() != first.getType()) { if (vertexInputs[i].getType() != first.getType()) {
throw new InvalidInputTypeException( throw new InvalidInputTypeException(
@ -142,6 +147,8 @@ public class MergeVertex extends GraphVertex {
break; break;
case RNN: case RNN:
thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize(); thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize();
format = ((InputType.InputTypeRecurrent) vertexInputs[i]).getFormat();
this.mergeAxis = format == RNNFormat.NCW ? 1 : 2;
type = InputType.Type.RNN; type = InputType.Type.RNN;
break; break;
default: default:
@ -160,7 +167,7 @@ public class MergeVertex extends GraphVertex {
return InputType.feedForward(size); return InputType.feedForward(size);
} else { } else {
val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength(); val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength();
return InputType.recurrent(size, tsLength); return InputType.recurrent(size, tsLength, format);
} }
} else { } else {
//size is unknown //size is unknown
@ -168,13 +175,14 @@ public class MergeVertex extends GraphVertex {
return InputType.feedForward(-1); return InputType.feedForward(-1);
} else { } else {
val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength(); val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength();
return InputType.recurrent(-1, tsLength); return InputType.recurrent(-1, tsLength, format);
} }
} }
} else { } else {
//CNN inputs... also check that the channels, width and heights match: //CNN inputs... also check that the channels, width and heights match:
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
CNN2DFormat format = firstConv.getFormat();
val fd = firstConv.getChannels(); val fd = firstConv.getChannels();
val fw = firstConv.getWidth(); val fw = firstConv.getWidth();
@ -206,7 +214,8 @@ public class MergeVertex extends GraphVertex {
depthSum += od; depthSum += od;
} }
return InputType.convolutional(fh, fw, depthSum); this.mergeAxis = format == CNN2DFormat.NCHW ? 1 : 3;
return InputType.convolutional(fh, fw, depthSum, format);
} }
} }

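The key idea in the MergeVertex change above: concatenation must happen on the channels dimension, and the index of that dimension depends on the data format. A small standalone sketch of the axis choice (helper and class names hypothetical):

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;

public class MergeAxisSketch {
    // Mirrors the axis selection made in MergeVertex.getOutputType(...)
    static int mergeAxis(RNNFormat f) {
        return f == RNNFormat.NCW ? 1 : 2;    // channels axis for sequence activations
    }
    static int mergeAxis(CNN2DFormat f) {
        return f == CNN2DFormat.NCHW ? 1 : 3; // channels axis for 2D CNN activations
    }
    public static void main(String[] args) {
        System.out.println(mergeAxis(RNNFormat.NWC));    // 2
        System.out.println(mergeAxis(CNN2DFormat.NHWC)); // 3
    }
}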
View File

@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.Convolution3D;
@ -91,7 +92,11 @@ public abstract class InputType implements Serializable {
* @return InputTypeFeedForward * @return InputTypeFeedForward
*/ */
public static InputType feedForward(long size) { public static InputType feedForward(long size) {
return new InputTypeFeedForward(size); return new InputTypeFeedForward(size, null);
}
public static InputType feedForward(long size, DataFormat timeDistributedFormat) {
return new InputTypeFeedForward(size,timeDistributedFormat);
} }
/** /**
@ -132,7 +137,6 @@ public abstract class InputType implements Serializable {
* @return InputTypeConvolutional * @return InputTypeConvolutional
*/ */
public static InputType convolutional(long height, long width, long depth) { public static InputType convolutional(long height, long width, long depth) {
// return new InputTypeConvolutional(height, width, depth);
return convolutional(height, width, depth, CNN2DFormat.NCHW); return convolutional(height, width, depth, CNN2DFormat.NCHW);
} }
@ -191,9 +195,11 @@ public abstract class InputType implements Serializable {
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
public static class InputTypeFeedForward extends InputType { public static class InputTypeFeedForward extends InputType {
private long size; private long size;
private DataFormat timeDistributedFormat;
public InputTypeFeedForward(@JsonProperty("size") long size) { public InputTypeFeedForward(@JsonProperty("size") long size, @JsonProperty("timeDistributedFormat") DataFormat timeDistributedFormat) {
this.size = size; this.size = size;
this.timeDistributedFormat = timeDistributedFormat;
} }
@Override @Override
@ -203,7 +209,7 @@ public abstract class InputType implements Serializable {
@Override @Override
public String toString() { public String toString() {
return "InputTypeFeedForward(" + size + ")"; return "InputTypeFeedForward(" + size + (timeDistributedFormat != null ? "," + timeDistributedFormat : "") + ")";
} }
@Override @Override
@ -302,7 +308,8 @@ public abstract class InputType implements Serializable {
this.height = height; this.height = height;
this.width = width; this.width = width;
this.channels = channels; this.channels = channels;
this.format = format; if(format != null)
this.format = format;
} }
public InputTypeConvolutional(long height, long width, long channels) { public InputTypeConvolutional(long height, long width, long channels) {

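A short sketch of the new feedForward overload above: an FF input type can now remember the sequence format it came from, so a later FF -> RNN preprocessor can restore the right layout (class name hypothetical, not part of this commit):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;

public class TimeDistributedInputTypeSketch {
    public static void main(String[] args) {
        InputType plain = InputType.feedForward(128);                  // no time-distributed context
        InputType fromNwc = InputType.feedForward(128, RNNFormat.NWC); // remembers its NWC origin
        System.out.println(fromNwc); // InputTypeFeedForward(128,NWC)
    }
}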
View File

@ -64,11 +64,11 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
+ "\"): expect RNN input type with size > 0. Got: " + inputType); + "\"): expect RNN input type with size > 0. Got: " + inputType);
} }
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
if (nIn <= 0 || override) { if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.nIn = r.getSize(); this.nIn = r.getSize();
this.rnnDataFormat = r.getFormat();
} }
this.rnnDataFormat = r.getFormat();
} }
@Override @Override

View File

@ -44,6 +44,7 @@ import java.util.Map;
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
public class Convolution1DLayer extends ConvolutionLayer { public class Convolution1DLayer extends ConvolutionLayer {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
/* /*
//TODO: We will eventually want to NOT subclass off of ConvolutionLayer. //TODO: We will eventually want to NOT subclass off of ConvolutionLayer.
//Currently, we just subclass off the ConvolutionLayer and hard code the "width" dimension to 1 //Currently, we just subclass off the ConvolutionLayer and hard code the "width" dimension to 1
@ -56,6 +57,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
private Convolution1DLayer(Builder builder) { private Convolution1DLayer(Builder builder) {
super(builder); super(builder);
initializeConstraints(builder); initializeConstraints(builder);
this.rnnDataFormat = builder.rnnDataFormat;
} }
@Override @Override
@ -92,7 +94,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0], outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
convolutionMode, dilation[0]); convolutionMode, dilation[0]);
} }
return InputType.recurrent(nOut, outLength);
return InputType.recurrent(nOut, outLength, rnnDataFormat);
} }
@Override @Override
@ -102,10 +105,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ "\"): expect RNN input type with size > 0. Got: " + inputType); + "\"): expect RNN input type with size > 0. Got: " + inputType);
} }
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
if (nIn <= 0 || override) { if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.nIn = r.getSize(); this.nIn = r.getSize();
} }
this.rnnDataFormat = r.getFormat();
} }
@Override @Override
@ -115,11 +119,13 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ "\"): input is null"); + "\"): input is null");
} }
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName()); return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat,getLayerName());
} }
public static class Builder extends ConvolutionLayer.BaseConvBuilder<Builder> { public static class Builder extends ConvolutionLayer.BaseConvBuilder<Builder> {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public Builder() { public Builder() {
this(0, 1, 0); this(0, 1, 0);
this.setKernelSize((int[]) null); this.setKernelSize((int[]) null);
@ -130,6 +136,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
return true; return true;
} }
public Builder rnnDataFormat(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
return this;
}
/** /**
* @param kernelSize Kernel size * @param kernelSize Kernel size
* @param stride Stride * @param stride Stride

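How a configuration might opt into the new rnnDataFormat option, as a sketch (layer sizes arbitrary, class name hypothetical, not from this commit):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;

public class Conv1dNwcConfigSketch {
    public static void main(String[] args) {
        Convolution1DLayer conv1d = new Convolution1DLayer.Builder()
                .kernelSize(3)
                .stride(1)
                .nIn(8)
                .nOut(16)
                .rnnDataFormat(RNNFormat.NWC) // accept and emit [minibatch, length, channels]
                .build();
        System.out.println(conv1d);
    }
}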
View File

@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport;
@ -58,12 +59,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
private int inputLength = 1; // By default only use one index to embed private int inputLength = 1; // By default only use one index to embed
private boolean hasBias = false; private boolean hasBias = false;
private boolean inferInputLength = false; // use input length as provided by input data private boolean inferInputLength = false; // use input length as provided by input data
private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models
private EmbeddingSequenceLayer(Builder builder) { private EmbeddingSequenceLayer(Builder builder) {
super(builder); super(builder);
this.hasBias = builder.hasBias; this.hasBias = builder.hasBias;
this.inputLength = builder.inputLength; this.inputLength = builder.inputLength;
this.inferInputLength = builder.inferInputLength; this.inferInputLength = builder.inferInputLength;
this.outputFormat = builder.outputFormat;
initializeConstraints(builder); initializeConstraints(builder);
} }
@ -87,7 +90,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
throw new IllegalStateException("Invalid input for Embedding layer (layer index = " + layerIndex throw new IllegalStateException("Invalid input for Embedding layer (layer index = " + layerIndex
+ ", layer name = \"" + getLayerName() + "\"): expect FF/RNN input type. Got: " + inputType); + ", layer name = \"" + getLayerName() + "\"): expect FF/RNN input type. Got: " + inputType);
} }
return InputType.recurrent(nOut, inputLength); return InputType.recurrent(nOut, inputLength, outputFormat);
} }
@Override @Override
@ -167,6 +170,13 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
*/ */
private boolean inferInputLength = true; private boolean inferInputLength = true;
private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models
public Builder outputDataFormat(RNNFormat format){
this.outputFormat = format;
return this;
}
/** /**
* If true: include bias parameters in the layer. False (default): no bias. * If true: include bias parameters in the layer. False (default): no bias.
* *

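A usage sketch for the new outputDataFormat option above (sizes arbitrary, class name hypothetical, not from this commit):

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;

public class EmbeddingNwcConfigSketch {
    public static void main(String[] args) {
        EmbeddingSequenceLayer emb = new EmbeddingSequenceLayer.Builder()
                .nIn(10000)                      // vocabulary size
                .nOut(64)                        // embedding dimension
                .inputLength(50)
                .outputDataFormat(RNNFormat.NWC) // emit [minibatch, length, embedding] rather than default NCW
                .build();
        System.out.println(emb);
    }
}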
View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import lombok.*; import lombok.*;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
@ -35,6 +36,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
protected long nIn; protected long nIn;
protected long nOut; protected long nOut;
protected DataFormat timeDistributedFormat;
public FeedForwardLayer(Builder builder) { public FeedForwardLayer(Builder builder) {
super(builder); super(builder);
@ -51,7 +53,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
+ getLayerName() + "\"): expected FeedForward input type. Got: " + inputType); + getLayerName() + "\"): expected FeedForward input type. Got: " + inputType);
} }
return InputType.feedForward(nOut); return InputType.feedForward(nOut, timeDistributedFormat);
} }
@Override @Override
@ -71,6 +73,11 @@ public abstract class FeedForwardLayer extends BaseLayer {
this.nIn = f.getFlattenedSize(); this.nIn = f.getFlattenedSize();
} }
} }
if(inputType instanceof InputType.InputTypeFeedForward){
InputType.InputTypeFeedForward f = (InputType.InputTypeFeedForward) inputType;
this.timeDistributedFormat = f.getTimeDistributedFormat();
}
} }
@Override @Override

View File

@ -536,11 +536,17 @@ public class InputTypeUtil {
} }
switch (inputType.getType()) { switch (inputType.getType()) {
case FF:
case CNNFlat: case CNNFlat:
//FF -> RNN or CNNFlat -> RNN //FF -> RNN or CNNFlat -> RNN
//In either case, input data format is a row vector per example //In either case, input data format is a row vector per example
return new FeedForwardToRnnPreProcessor(rnnDataFormat); return new FeedForwardToRnnPreProcessor(rnnDataFormat);
case FF:
//If time distributed format is defined, use that. Otherwise use the layer-defined rnnDataFormat, which may be default
InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward)inputType;
if(ff.getTimeDistributedFormat() != null && ff.getTimeDistributedFormat() instanceof RNNFormat){
return new FeedForwardToRnnPreProcessor((RNNFormat) ff.getTimeDistributedFormat());
}
return new FeedForwardToRnnPreProcessor(rnnDataFormat);
case RNN: case RNN:
//RNN -> RNN: No preprocessor necessary //RNN -> RNN: No preprocessor necessary
return null; return null;

View File

@ -98,9 +98,9 @@ public class RnnOutputLayer extends BaseOutputLayer {
+ "\"): Expected RNN input, got " + inputType); + "\"): Expected RNN input, got " + inputType);
} }
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.rnnDataFormat = r.getFormat();
if (nIn <= 0 || override) { if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.rnnDataFormat = r.getFormat();
this.nIn = r.getSize(); this.nIn = r.getSize();
} }
} }

View File

@ -91,7 +91,7 @@ public class Subsampling1DLayer extends SubsamplingLayer {
outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0], outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
convolutionMode, dilation[0]); convolutionMode, dilation[0]);
} }
return InputType.recurrent(r.getSize(), outLength); return InputType.recurrent(r.getSize(), outLength, r.getFormat());
} }
@Override @Override

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers.misc;
import lombok.*; import lombok.*;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
@ -46,10 +47,12 @@ import java.util.Map;
public class RepeatVector extends FeedForwardLayer { public class RepeatVector extends FeedForwardLayer {
private int n = 1; private int n = 1;
private RNNFormat dataFormat = RNNFormat.NCW;
protected RepeatVector(Builder builder) { protected RepeatVector(Builder builder) {
super(builder); super(builder);
this.n = builder.n; this.n = builder.n;
this.dataFormat = builder.dataFormat;
} }
@Override @Override
@ -83,7 +86,7 @@ public class RepeatVector extends FeedForwardLayer {
+ "\"): Expected FF input, got " + inputType); + "\"): Expected FF input, got " + inputType);
} }
InputType.InputTypeFeedForward ffInput = (InputType.InputTypeFeedForward) inputType; InputType.InputTypeFeedForward ffInput = (InputType.InputTypeFeedForward) inputType;
return InputType.recurrent(ffInput.getSize(), n); return InputType.recurrent(ffInput.getSize(), n, this.dataFormat);
} }
@Override @Override
@ -101,13 +104,14 @@ public class RepeatVector extends FeedForwardLayer {
} }
@NoArgsConstructor @NoArgsConstructor
@Getter @Getter
@Setter @Setter
public static class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> { public static class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {
private int n = 1; // no repetition by default private int n = 1; // no repetition by default
private RNNFormat dataFormat = RNNFormat.NCW;
/** /**
* Set repetition factor for RepeatVector layer * Set repetition factor for RepeatVector layer
*/ */
@ -115,6 +119,15 @@ public class RepeatVector extends FeedForwardLayer {
return n; return n;
} }
public RNNFormat getDataFormat(){
return dataFormat;
}
public Builder dataFormat(RNNFormat dataFormat){
this.dataFormat = dataFormat;
return this;
}
/** /**
* Set repetition factor for RepeatVector layer * Set repetition factor for RepeatVector layer
* *

View File

@ -39,11 +39,13 @@ import java.util.Arrays;
* For example, CNN -> Denselayer <br> * For example, CNN -> Denselayer <br>
* This does two things:<br> * This does two things:<br>
* (a) Reshapes 4d activations out of CNN layer, with shape * (a) Reshapes 4d activations out of CNN layer, with shape
* [numExamples, numChannels, inputHeight, inputWidth] into 2d activations (with shape * [numExamples, numChannels, inputHeight, inputWidth] (for {@link CNN2DFormat#NCHW} format activations) or shape
* [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer * [numExamples, inputHeight, inputWidth, numChannels] (for {@link CNN2DFormat#NHWC} format activations) into 2d activations
* (with shape [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer.
* (b) Reshapes epsilons (weights*deltas) out of FeedForward layer (which is 2D or 3D with shape * (b) Reshapes epsilons (weights*deltas) out of FeedForward layer (which is 2D or 3D with shape
* [numExamples, inputHeight*inputWidth*numChannels]) into 4d epsilons (with shape * [numExamples, inputHeight*inputWidth*numChannels]) into 4d epsilons (with shape
* [numExamples, numChannels, inputHeight, inputWidth]) suitable to feed into CNN layers.<br> * [numExamples, numChannels, inputHeight, inputWidth] or [numExamples, inputHeight, inputWidth, numChannels]) suitable to
* feed into CNN layers.<br>
* Note: numChannels is equivalent to channels or featureMaps referenced in different literature * Note: numChannels is equivalent to channels or featureMaps referenced in different literature
* @author Adam Gibson * @author Adam Gibson
* @see FeedForwardToCnnPreProcessor for opposite case (i.e., DenseLayer -> CNNetc) * @see FeedForwardToCnnPreProcessor for opposite case (i.e., DenseLayer -> CNNetc)
@ -68,7 +70,8 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
this.inputHeight = inputHeight; this.inputHeight = inputHeight;
this.inputWidth = inputWidth; this.inputWidth = inputWidth;
this.numChannels = numChannels; this.numChannels = numChannels;
this.format = format; if(format != null)
this.format = format;
} }
public CnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) { public CnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
@ -96,10 +99,17 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
wDim = 2; wDim = 2;
} }
if(inputHeight == 0 && inputWidth == 0 && numChannels == 0){
this.inputHeight = input.size(hDim);
this.inputWidth = input.size(wDim);
this.numChannels = input.size(chDim);
}
if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || input.size(wDim) != inputWidth){ if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || input.size(wDim) != inputWidth){
throw new IllegalStateException("Invalid input, does not match configuration: expected [minibatch, numChannels="
+ numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] but got input array of" +
"shape " + Arrays.toString(input.shape()));
throw new IllegalStateException("Invalid input, does not match configuration: expected " +
(format == CNN2DFormat.NCHW ? "[minibatch, numChannels=" + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] " :
"[minibatch, inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + ", numChannels=" + numChannels + "]") +
" but got input array of shape " + Arrays.toString(input.shape()));
} }
//Check input: nchw format //Check input: nchw format
@ -110,15 +120,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
+ Arrays.toString(input.shape())); + Arrays.toString(input.shape()));
} }
if(format == CNN2DFormat.NHWC) {
input = input.permute(0, 3, 1, 2); //NHWC to NCHW
}
//Assume input is standard rank 4 activations out of CNN layer //Assume input is standard rank 4 activations out of CNN layer
//First: we require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct //First: we require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
//Note that to match Tensorflow/Keras, we do a simple "c order reshape" for both NCHW and NHWC
val inShape = input.shape(); //[miniBatch,depthOut,outH,outW] val inShape = input.shape(); //[miniBatch,depthOut,outH,outW]
val outShape = new long[]{inShape[0], inShape[1] * inShape[2] * inShape[3]}; val outShape = new long[]{inShape[0], inShape[1] * inShape[2] * inShape[3]};
@ -139,11 +147,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
+ inputHeight + " x columns " + inputWidth + " x channels " + numChannels + " but was instead " + inputHeight + " x columns " + inputWidth + " x channels " + numChannels + " but was instead "
+ Arrays.toString(epsilons.shape())); + Arrays.toString(epsilons.shape()));
INDArray ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
if(format == CNN2DFormat.NHWC){
ret = ret.permute(0,2,3,1); //NCHW to NHWC
}
INDArray ret;
if(format == CNN2DFormat.NCHW){
ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
} else {
ret = epsilons.reshape('c', epsilons.size(0), inputHeight, inputWidth, numChannels);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace
} }

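The behavioural point of the preprocessor change above: to match a Keras/TF Flatten, NHWC activations must be flattened with a plain c-order reshape, not permuted to NCHW first. A standalone sketch illustrating the difference in element order (class name hypothetical, not part of this commit):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class FlattenOrderSketch {
    public static void main(String[] args) {
        // 1 example, 2x2 spatial, 3 channels, NHWC layout
        INDArray nhwc = Nd4j.arange(12).reshape('c', 1, 2, 2, 3);
        // Keras/TF Flatten: plain c-order reshape, channels vary fastest
        INDArray kerasStyle = nhwc.reshape('c', 1, 12);
        // Old DL4J behaviour: permute to NCHW first, then reshape - a different element order
        INDArray oldStyle = nhwc.permute(0, 3, 1, 2).dup('c').reshape('c', 1, 12);
        System.out.println(kerasStyle);
        System.out.println(oldStyle);
    }
}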
View File

@ -52,7 +52,8 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW; private RNNFormat rnnDataFormat = RNNFormat.NCW;
public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){ public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat; if(rnnDataFormat != null)
this.rnnDataFormat = rnnDataFormat;
} }
@Override @Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {

View File

@ -57,7 +57,8 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW; private RNNFormat rnnDataFormat = RNNFormat.NCW;
public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){ public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat; if(rnnDataFormat != null)
this.rnnDataFormat = rnnDataFormat;
} }
@Override @Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
@ -116,7 +117,7 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
} }
InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType; InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
return InputType.feedForward(rnn.getSize()); return InputType.feedForward(rnn.getSize(), rnn.getFormat());
} }
@Override @Override

View File

@ -0,0 +1,52 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf.serde.format;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import java.io.IOException;
/**
* Simple JSON deserializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat}
*
* @author Alex Black
*/
public class DataFormatDeserializer extends JsonDeserializer<DataFormat> {
@Override
public DataFormat deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
JsonNode node = jp.getCodec().readTree(jp);
String text = node.textValue();
switch (text){
case "NCHW":
return CNN2DFormat.NCHW;
case "NHWC":
return CNN2DFormat.NHWC;
case "NCW":
return RNNFormat.NCW;
case "NWC":
return RNNFormat.NWC;
default:
return null;
}
}
}

View File

@ -0,0 +1,37 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf.serde.format;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import java.io.IOException;
/**
* Simple JSON serializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat}
*
* @author Alex Black
*/
public class DataFormatSerializer extends JsonSerializer<DataFormat> {
@Override
public void serialize(DataFormat dataFormat, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
jsonGenerator.writeString(dataFormat.toString());
}
}

View File

@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.ArrayType;
@ -48,14 +49,16 @@ public class MergeVertex extends BaseGraphVertex {
private long[][] forwardPassShapes; private long[][] forwardPassShapes;
private int fwdPassRank; private int fwdPassRank;
private int mergeAxis;
public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType) { public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType, int mergeAxis) {
this(graph, name, vertexIndex, null, null, dataType); this(graph, name, vertexIndex, null, null, dataType, mergeAxis);
} }
public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
VertexIndices[] outputVertices, DataType dataType) { VertexIndices[] outputVertices, DataType dataType, int mergeAxis) {
super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
this.mergeAxis = mergeAxis;
} }
@Override @Override
@ -92,7 +95,6 @@ public class MergeVertex extends BaseGraphVertex {
forwardPassShapes = new long[in.length][0]; forwardPassShapes = new long[in.length][0];
val nExamples = in[0].size(0); val nExamples = in[0].size(0);
int nOut = 0;
fwdPassRank = in[0].rank(); fwdPassRank = in[0].rank();
for (int i = 0; i < in.length; i++) { for (int i = 0; i < in.length; i++) {
val currShape = in[i].shape(); val currShape = in[i].shape();
@ -109,12 +111,11 @@ public class MergeVertex extends BaseGraphVertex {
+ Arrays.toString(in[0].shape()) + ", activations[" + i + Arrays.toString(in[0].shape()) + ", activations[" + i
+ "] shape: " + Arrays.toString(in[i].shape())); + "] shape: " + Arrays.toString(in[i].shape()));
} }
nOut += currShape[1]; //Same dimension for all of CNNs, FF, RNNs
} }
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){ try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
return Nd4j.concat(1, in); INDArray out = Nd4j.concat(mergeAxis, in);
return out;
} }
} }
@ -145,20 +146,16 @@ public class MergeVertex extends BaseGraphVertex {
break; break;
case 3: case 3:
for (int i = 0; i < forwardPassShapes.length; i++) { for (int i = 0; i < forwardPassShapes.length; i++) {
out[i].assign(epsilon.get(NDArrayIndex.all(), //All rows
NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //subset of columns
NDArrayIndex.all())); //All time steps
out[i].assign(epsilon.get(indices(3, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //All time steps
cumulative += forwardPassShapes[i][1]; cumulative += forwardPassShapes[i][mergeAxis];
} }
break; break;
case 4: case 4:
for (int i = 0; i < forwardPassShapes.length; i++) { for (int i = 0; i < forwardPassShapes.length; i++) {
out[i].assign(epsilon.get(NDArrayIndex.all(),
NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //Subset of depth
NDArrayIndex.all(), //Width
NDArrayIndex.all())); //height
out[i].assign(epsilon.get(indices(4, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //height
cumulative += forwardPassShapes[i][1]; cumulative += forwardPassShapes[i][mergeAxis];
} }
break; break;
default: default:
@ -168,6 +165,19 @@ public class MergeVertex extends BaseGraphVertex {
return new Pair<>(null, out); return new Pair<>(null, out);
} }
private INDArrayIndex[] indices(int num, int axis, long from, long to){
INDArrayIndex[] out = new INDArrayIndex[num];
for( int i=0; i<num; i++ ){
if(i == axis){
out[i] = NDArrayIndex.interval(from, to);
} else {
out[i] = NDArrayIndex.all();
}
}
return out;
}
@Override @Override
public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) { public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
if (backpropGradientsViewArray != null) if (backpropGradientsViewArray != null)

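A standalone sketch of the dynamic-axis slicing idea used by the new indices(...) helper above (class name hypothetical, not part of this commit):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class AxisSliceSketch {
    // Same idea as MergeVertex.indices(...): interval on one axis, all() everywhere else
    static INDArrayIndex[] indices(int rank, int axis, long from, long to) {
        INDArrayIndex[] out = new INDArrayIndex[rank];
        for (int i = 0; i < rank; i++) {
            out[i] = (i == axis) ? NDArrayIndex.interval(from, to) : NDArrayIndex.all();
        }
        return out;
    }
    public static void main(String[] args) {
        INDArray eps = Nd4j.arange(2 * 5 * 3).reshape('c', 2, 5, 3); // e.g. an NWC epsilon
        // Take channels [0,2) along axis 2, the NWC channels axis
        INDArray first = eps.get(indices(3, 2, 0, 2));
        System.out.println(java.util.Arrays.toString(first.shape())); // [2, 5, 2]
    }
}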
View File

@ -79,7 +79,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
ILossFunction lossFunction = layerConf().getLossFn(); ILossFunction lossFunction = layerConf().getLossFn();
double score = lossFunction.computeScore(getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut,
INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
double score = lossFunction.computeScore(labels2d, preOut,
layerConf().getActivationFn(), maskArray,false); layerConf().getActivationFn(), maskArray,false);
if(conf().isMiniBatch()) if(conf().isMiniBatch())

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D; import org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D;
@ -71,7 +72,11 @@ public class RepeatVector extends AbstractLayer<org.deeplearning4j.nn.conf.layer
INDArray outEpsilon; INDArray outEpsilon;
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){ try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
outEpsilon = epsilon.sum(2); if (layerConf().getDataFormat() == RNNFormat.NCW) {
outEpsilon = epsilon.sum(2);
}else{
outEpsilon = epsilon.sum(1);
}
} }
Gradient gradient = new DefaultGradient(); Gradient gradient = new DefaultGradient();
@ -99,10 +104,22 @@ public class RepeatVector extends AbstractLayer<org.deeplearning4j.nn.conf.layer
long miniBatch = input.size(0); long miniBatch = input.size(0);
long size = input.size(1); long size = input.size(1);
INDArray output = input.reshape(miniBatch, size, 1).castTo(dataType);
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
return output.repeat(2, (long) getN());
}
if (getDataFormat() == RNNFormat.NCW) {
INDArray output = input.reshape(miniBatch, size, 1).castTo(dataType);
try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
return output.repeat(2, (long) getN());
}
}
else{
INDArray output = input.reshape(miniBatch, 1, size).castTo(dataType);
try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
return output.repeat(1, (long) getN());
}
}
}
public RNNFormat getDataFormat(){
return layerConf().getDataFormat();
} }
@Override @Override

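What the RepeatVector change above does, as a standalone ND4J sketch: the repeat axis is the time dimension, which differs between NCW and NWC (class name hypothetical, not part of this commit):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RepeatVectorSketch {
    public static void main(String[] args) {
        INDArray in = Nd4j.ones(4, 8); // [minibatch, size]
        int n = 5;
        // NCW: [mb, size, 1] repeated along dim 2 -> [4, 8, 5]
        INDArray ncw = in.reshape(4, 8, 1).repeat(2, n);
        // NWC: [mb, 1, size] repeated along dim 1 -> [4, 5, 8]
        INDArray nwc = in.reshape(4, 1, 8).repeat(1, n);
        System.out.println(java.util.Arrays.toString(ncw.shape()));
        System.out.println(java.util.Arrays.toString(nwc.shape()));
    }
}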
View File

@ -20,6 +20,8 @@ import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
@ -74,6 +76,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ Arrays.toString(epsilon.shape()) + Arrays.toString(epsilon.shape())
+ ". Expected rank 3 array with shape [minibatchSize, features, length]. " + layerId()); + ". Expected rank 3 array with shape [minibatchSize, features, length]. " + layerId());
if (getRnnDataFormat() == RNNFormat.NWC){
epsilon = epsilon.permute(0, 2, 1);
this.input = input.permute(0, 2, 1);
}
if(maskArray != null){ if(maskArray != null){
INDArray maskOut = feedForwardMaskArray(maskArray, MaskState.Active, (int)epsilon.size(0)).getFirst(); INDArray maskOut = feedForwardMaskArray(maskArray, MaskState.Active, (int)epsilon.size(0)).getFirst();
Preconditions.checkState(epsilon.size(0) == maskOut.size(0) && epsilon.size(2) == maskOut.size(1), Preconditions.checkState(epsilon.size(0) == maskOut.size(0) && epsilon.size(2) == maskOut.size(1),
@ -125,6 +131,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY)); retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY));
} }
retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c'); retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c');
if (getRnnDataFormat() == RNNFormat.NWC){
epsOut = epsOut.permute(0, 2, 1);
this.input = input.permute(0, 2, 1);
}
return new Pair<>(retGradient, epsOut); return new Pair<>(retGradient, epsOut);
} }
@ -140,7 +150,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
// remove singleton fourth dimension from input and current epsilon // remove singleton fourth dimension from input and current epsilon
epsNext = epsNext.reshape(epsNext.size(0), epsNext.size(1), epsNext.size(2)); epsNext = epsNext.reshape(epsNext.size(0), epsNext.size(1), epsNext.size(2));
input = origInput; input = origInput;
if (getRnnDataFormat() == RNNFormat.NWC){
epsNext = epsNext.permute(0, 2, 1);
this.input = input.permute(0, 2, 1);
}
return new Pair<>(gradientEpsNext.getFirst(), epsNext); return new Pair<>(gradientEpsNext.getFirst(), epsNext);
} }
@ -185,7 +198,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
.s(c.getStride()[0]) .s(c.getStride()[0])
.d(c.getDilation()[0]) .d(c.getDilation()[0])
.p(c.getPadding()[0]) .p(c.getPadding()[0])
.dataFormat(Conv1DConfig.NCW) .dataFormat((((org.deeplearning4j.nn.conf.layers.Convolution1DLayer)
layerConf()).getRnnDataFormat()== RNNFormat.NCW)?Conv1DConfig.NCW: Conv1DConfig.NCW) //Both branches NCW: NWC input is permuted to NCW earlier in activate()/backpropGradient(), so the op always runs NCW
.paddingMode(PaddingMode.CAUSAL) .paddingMode(PaddingMode.CAUSAL)
.build(); .build();
INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY); INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
@ -209,6 +223,9 @@ public class Convolution1DLayer extends ConvolutionLayer {
@Override @Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){ public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
if (getRnnDataFormat() == RNNFormat.NWC){
this.input = input.permute(0, 2, 1);
}
INDArray act4d = super.activate(training, workspaceMgr); INDArray act4d = super.activate(training, workspaceMgr);
INDArray act3d = act4d.reshape(act4d.size(0), act4d.size(1), act4d.size(2)); INDArray act3d = act4d.reshape(act4d.size(0), act4d.size(1), act4d.size(2));
@ -219,6 +236,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
act3d.shape(), maskOut.shape()); act3d.shape(), maskOut.shape());
Broadcast.mul(act3d, maskOut, act3d, 0, 2); Broadcast.mul(act3d, maskOut, act3d, 0, 2);
} }
if (getRnnDataFormat() == RNNFormat.NWC){
this.input = input.permute(0, 2, 1);
act3d = act3d.permute(0, 2, 1);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, act3d); //Should be zero copy most of the time return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, act3d); //Should be zero copy most of the time
} }
@ -231,4 +252,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
layerConf().getConvolutionMode()); layerConf().getConvolutionMode());
return new Pair<>(reduced, currentMaskState); return new Pair<>(reduced, currentMaskState);
} }
private RNNFormat getRnnDataFormat(){
return ((org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf()).getRnnDataFormat();
}
} }

View File

@ -160,7 +160,8 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr); Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr);
INDArray z = p.getFirst(); INDArray z = p.getFirst();
if(layerConf().getCnn2dDataFormat() != CNN2DFormat.NCHW){ CNN2DFormat f = layerConf().getCnn2dDataFormat();
if(f != CNN2DFormat.NCHW){
z = z.permute(0,3,1,2); //NHWC to NCHW z = z.permute(0,3,1,2); //NHWC to NCHW
} }
delta = afn.backprop(z, epsilon).getFirst(); //TODO handle activation function params delta = afn.backprop(z, epsilon).getFirst(); //TODO handle activation function params

View File

@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.nn.layers.BaseLayer;
@ -64,8 +65,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
INDArray z = preOutput(true, workspaceMgr); INDArray z = preOutput(true, workspaceMgr);
INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); //Shape: [mb, vector, seqLength] INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); //Shape: [mb, vector, seqLength]
boolean ncw = layerConf().getOutputFormat() == RNNFormat.NCW;
if (maskArray != null) { if (maskArray != null) {
delta = Broadcast.mul(delta, maskArray, delta, 0, 2);
if(ncw){
delta = Broadcast.mul(delta, maskArray, delta, 0, 2);
} else {
delta = Broadcast.mul(delta, maskArray, delta, 0, 1);
}
} }
int inputLength = layerConf().getInputLength(); int inputLength = layerConf().getInputLength();
@ -76,7 +83,10 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
delta = delta.dup('c'); delta = delta.dup('c');
} }
delta = delta.permute(0, 2, 1); //From [minibatch, nOut, length] to [minibatch, length, nOut]
if(ncw){
delta = delta.permute(0, 2, 1); //From [minibatch, nOut, length] to [minibatch, length, nOut]
}
delta = delta.reshape('c',inputLength * numSamples, nOut); delta = delta.reshape('c',inputLength * numSamples, nOut);
INDArray weightGradients = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY); INDArray weightGradients = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY);
@ -159,7 +169,10 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
} }
val shape = new long[]{minibatch, inputLength, nOut}; val shape = new long[]{minibatch, inputLength, nOut};
INDArray ret = rows.reshape('c', shape).permute(0, 2, 1); INDArray ret = rows.reshape('c', shape);
if(layerConf().getOutputFormat() == RNNFormat.NCW){
ret = ret.permute(0, 2, 1); //[minibatch, seqLen, nOut] -> [minibatch, nOut, seqLen] i.e., NWC -> NCW
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret); return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
} }
@ -177,8 +190,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
" 2 (when input is rank 3, shape [mb,1,tsLength]). Input shape: " + Arrays.toString(input.shape()) + " 2 (when input is rank 3, shape [mb,1,tsLength]). Input shape: " + Arrays.toString(input.shape()) +
", mask shape: " + Arrays.toString(maskArray.shape())); ", mask shape: " + Arrays.toString(maskArray.shape()));
} }
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
boolean ncw = layerConf().getOutputFormat() == RNNFormat.NCW;
if(ncw){
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
} else {
//Returned array: rank 3, shape [mb, seqLength, vector]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 1);
}
} }
return ret; return ret;
} }

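The masking change above in one standalone sketch: the [minibatch, seqLength] mask broadcasts over dimensions (0,2) for NCW activations but (0,1) for NWC (class name hypothetical, not part of this commit):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;

public class MaskBroadcastSketch {
    public static void main(String[] args) {
        INDArray mask = Nd4j.ones(2, 7);    // [minibatch, seqLength]
        INDArray ncw = Nd4j.ones(2, 16, 7); // [mb, vector, seqLength]
        INDArray nwc = Nd4j.ones(2, 7, 16); // [mb, seqLength, vector]
        // Broadcast the mask over the dimensions it shares with the activations:
        Broadcast.mul(ncw, mask, ncw, 0, 2); // (mb, seqLength) map to dims (0, 2) in NCW
        Broadcast.mul(nwc, mask, nwc, 0, 1); // (mb, seqLength) map to dims (0, 1) in NWC
    }
}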
View File

@ -20,11 +20,13 @@ import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper; import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper; import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -152,6 +154,14 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
assertInputSet(false); assertInputSet(false);
Preconditions.checkState(input.rank() == 3, Preconditions.checkState(input.rank() == 3,
"3D input expected to RNN layer expected, got " + input.rank()); "3D input expected to RNN layer expected, got " + input.rank());
boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(layerConf()) == RNNFormat.NWC;
INDArray origInput = input;
if(nwc){
input = permuteIfNWC(input);
}
applyDropOutIfNecessary(training, workspaceMgr); applyDropOutIfNecessary(training, workspaceMgr);
//TODO LSTM cache mode is disabled for now - not passing all tests //TODO LSTM cache mode is disabled for now - not passing all tests
@ -166,7 +176,6 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
final INDArray recurrentWeights = getParamWithNoise(LSTMParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG] final INDArray recurrentWeights = getParamWithNoise(LSTMParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights = getParamWithNoise(LSTMParamInitializer.INPUT_WEIGHT_KEY, training, workspaceMgr); //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg] final INDArray inputWeights = getParamWithNoise(LSTMParamInitializer.INPUT_WEIGHT_KEY, training, workspaceMgr); //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray biases = getParamWithNoise(LSTMParamInitializer.BIAS_KEY, training, workspaceMgr); //by row: IFOG //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T final INDArray biases = getParamWithNoise(LSTMParamInitializer.BIAS_KEY, training, workspaceMgr); //by row: IFOG //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T
INDArray input = permuteIfNWC(this.input);
FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
input, recurrentWeights, inputWeights, biases, training, prevOutputActivations, input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
prevMemCellState, (training && cacheMode != CacheMode.NONE) || forBackprop, true, prevMemCellState, (training && cacheMode != CacheMode.NONE) || forBackprop, true,
@ -178,6 +187,11 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
if (training && cacheMode != CacheMode.NONE) { if (training && cacheMode != CacheMode.NONE) {
cachedFwdPass = fwd; cachedFwdPass = fwd;
} }
if(nwc){
input = origInput;
}
return fwd; return fwd;
} }

View File

@ -61,11 +61,8 @@ public class LastTimeStepLayer extends BaseWrapperLayer {
@Override @Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
long[] newEpsShape = origOutputShape; long[] newEpsShape = origOutputShape;
boolean nwc = (underlying instanceof BaseRecurrentLayer &&
((BaseRecurrentLayer) underlying).getDataFormat() == RNNFormat.NWC)||
(underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof
BaseRecurrentLayer && ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat()
== RNNFormat.NWC);
boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC;
INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f'); INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f');
if(lastTimeStepIdxs == null){ if(lastTimeStepIdxs == null){
//no mask case //no mask case

View File

@ -58,7 +58,8 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." + "Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId()); " Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
} }
int td = (layerConf().getRnnDataFormat()==RNNFormat.NCW)? 2: 1; RNNFormat format = layerConf().getRnnDataFormat();
int td = (format == RNNFormat.NCW) ? 2 : 1;
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels); Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnOutputLayer input and labels:" + Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels); "Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);

View File

@ -118,6 +118,7 @@ public class ConvolutionParamInitializer implements ParamInitializer {
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
conf.addVariable(WEIGHT_KEY); conf.addVariable(WEIGHT_KEY);
conf.addVariable(BIAS_KEY); conf.addVariable(BIAS_KEY);
} else { } else {
INDArray weightView = paramsView; INDArray weightView = paramsView;
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));

View File

@ -34,6 +34,7 @@ import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener; import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.junit.Before; import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -53,7 +54,7 @@ import static org.junit.Assert.*;
/** /**
* @author Tamas Fenyvesi * @author Tamas Fenyvesi
*/ */
@Slf4j @Slf4j @Ignore //https://github.com/eclipse/deeplearning4j/issues/8891
public class TestVertxUIMultiSession extends BaseDL4JTest { public class TestVertxUIMultiSession extends BaseDL4JTest {
@Before @Before