tf.keras import test and fixes (#347)

* merge conf

* merge conf

* tfkeras tests

* parameterized tests

* rename

* cuda versions

* jccp versions

* 'updates'

* updates

* rnn+mlp passing

* repeat

* updates

* tests

* Update pom.xml

* Update pom.xml

* rem print

* cnn1d model conversion fixed

* cnn1d activate fixed

* cnn1d outptut shape fix

* cnn1d bprop fix

* cnn1d stack fix

* KerasModelEndToEndTest - Remove permutes for NWC and NHWC format tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes and update test - input shapes (NCHW -> NHWC input)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Ignore for known bad tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Multiple fixes - MergeVertex, CNN1D layers, etc

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix issue with RNN/FF preprocessors, time distributed etc with NWC format

Signed-off-by: Alex Black <blacka101@gmail.com>

* LSTM NWC dropout fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Add sequence embedding layer NWC support (configurable output format)

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix expected shape in a couple of tests - NWC expected

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix EmbeddingSequenceLayer backprop for NWC output case + add gradient checks

Signed-off-by: Alex Black <blacka101@gmail.com>

* CnnToFeedForwardPreprocessor: align with Keras/TF; fix Keras reshape/flatten

Signed-off-by: Alex Black <blacka101@gmail.com>

* Update ConvDataFormatTests to match new reshape behaviour

Signed-off-by: Alex Black <blacka101@gmail.com>

* Switch hard-coded path to ResourceUtils.listClassPathfiles for TestTFKerasModelImport

Signed-off-by: Alex Black <blacka101@gmail.com>

* TestUtils fix

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix JSON serde issue with data formats

Signed-off-by: Alex Black <blacka101@gmail.com>

* Fix for input dtype inference; fix 2 tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Test fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* #8891 Ignore for TestVertxUIMultiSession until fixed

Signed-off-by: Alex Black <blacka101@gmail.com>

* Restore but deprecate TensorFlowCnnToFeedForwardPreProcessor for older zoo models

Signed-off-by: Alex Black <blacka101@gmail.com>

* Ignore for deprecated preprocessor in DTypeTests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Remove debug printlns

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
master
Fariz Rahman 2020-04-28 14:31:09 +04:00 committed by GitHub
parent b9d5f1645b
commit 4cb87a94e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
67 changed files with 1336 additions and 438 deletions

View File

@ -20,6 +20,7 @@ package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.cpython.global.python;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -343,6 +344,19 @@ public class PythonExecutioner {
if (path == null) {
log.info("Setting python default path");
File[] packages = numpy.cachePackages();
//// TODO: fix in javacpp
File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
File[] packages2 = new File[packages.length + 1];
for (int i = 0;i < packages.length; i++){
//System.out.println(packages[i].getAbsolutePath());
packages2[i] = packages[i];
}
packages2[packages.length] = sitePackagesWindows;
//System.out.println(sitePackagesWindows.getAbsolutePath());
packages = packages2;
//////////
Py_SetPath(packages);
} else {
log.info("Setting python path " + path);

View File

@ -0,0 +1,132 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.python;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
@Slf4j
public class PythonProcess {
private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
Process process = pb.start();
String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
process.waitFor();
return out;
}
public static void run(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
log.info("Executing command: " + Arrays.toString(allArgs));
ProcessBuilder pb = new ProcessBuilder(allArgs);
pb.inheritIO().start().waitFor();
}
public static void pipInstall(String packageName) throws PythonException{
try{
run("-m", "pip", "install", packageName);
}catch(Exception e){
throw new PythonException("Error installing package " + packageName, e);
}
}
public static void pipInstall(String packageName, String version) throws PythonException{
pipInstall(packageName + "==" + version);
}
public static void pipUninstall(String packageName) throws PythonException{
try{
run("-m", "pip", "uninstall", packageName);
}catch(Exception e){
throw new PythonException("Error uninstalling package " + packageName, e);
}
}
public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{
if (!gitRepoUrl.contains("://")){
gitRepoUrl = "git://" + gitRepoUrl;
}
try{
run("-m", "pip", "install", "git+", gitRepoUrl);
}catch(Exception e){
throw new PythonException("Error installing package from " + gitRepoUrl, e);
}
}
public static String getPackageVersion(String packageName) throws PythonException{
String out;
try{
out = runAndReturn("-m", "pip", "show", packageName);
} catch (Exception e){
throw new PythonException("Error finding version for package " + packageName, e);
}
if (!out.contains("Version: ")){
throw new PythonException("Can't find package " + packageName);
}
String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
return pkgVersion;
}
public static boolean isPackageInstalled(String packageName)throws PythonException{
try{
String out = runAndReturn("-m", "pip", "show", packageName);
return !out.isEmpty();
}catch (Exception e){
throw new PythonException("Error checking if package is installed: " +packageName, e);
}
}
public static void pipInstallFromRequirementsTxt(String path) throws PythonException{
try{
run("-m", "pip", "install","-r", path);
}catch (Exception e){
throw new PythonException("Error installing packages from " + path, e);
}
}
public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException{
try{
run(path, inplace?"develop":"install");
}catch (Exception e){
throw new PythonException("Error installing package from " + path, e);
}
}
}

View File

@ -0,0 +1,144 @@
package org.datavec.python.keras;
import org.datavec.python.Python;
import org.datavec.python.PythonException;
import org.datavec.python.PythonObject;
import org.datavec.python.PythonProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
public class Model {
private PythonObject pyModel;
private static PythonObject installAndImportTF() throws PythonException{
if (!PythonProcess.isPackageInstalled("tensorflow")){
PythonProcess.pipInstall("tensorflow");
}
return Python.importModule("tensorflow");
}
private static PythonObject getKerasModule() throws PythonException{
PythonObject tf = installAndImportTF();
PythonObject keras = tf.attr("keras");
tf.del();
return keras;
}
private static PythonObject loadModel(String s) throws PythonException{
PythonObject models = getKerasModule().attr("models");
PythonObject loadModelF = models.attr("load_model");
PythonObject model = loadModelF.call(s);
models.del();
loadModelF.del();
return model;
}
public Model(String path) throws PythonException{
pyModel = loadModel(path);
}
public INDArray[] predict(INDArray... inputs) throws PythonException{
PythonObject predictF = pyModel.attr("predict");
PythonObject inputList = new PythonObject(inputs);
PythonObject pyOut = predictF.call(inputList);
INDArray[] out;
if (Python.isinstance(pyOut, Python.listType())){
out = new INDArray[Python.len(pyOut).toInt()];
for(int i = 0; i < out.length; i++){
out[i] = pyOut.get(i).toNumpy().getNd4jArray();
}
}
else{
out = new INDArray[]{
pyOut.toNumpy().getNd4jArray()};
}
predictF.del();
inputList.del();
pyOut.del();
return out;
}
public int numInputs(){
PythonObject inputs = pyModel.attr("inputs");
PythonObject pyNumInputs = Python.len(inputs);
int ret = pyNumInputs.toInt();
inputs.del();
pyNumInputs.del();
return ret;
}
public int numOutputs(){
PythonObject outputs = pyModel.attr("outputs");
PythonObject pyNumOutputs = Python.len(outputs);
int ret = pyNumOutputs.toInt();
outputs.del();
pyNumOutputs.del();
return ret;
}
public long[][] inputShapes(){
long[][] ret = new long[numInputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = inputShapeAt(i);
}
return ret;
}
public long[][] outputShapes(){
long[][] ret = new long[numOutputs()][];
for (int i = 0; i < ret.length; i++){
ret[i] = outputShapeAt(i);
}
return ret;
}
public long[] inputShapeAt(int input){
PythonObject inputs = pyModel.attr("inputs");
PythonObject tensor = inputs.get(input);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
inputs.del();
return shape;
}
public long[] outputShapeAt(int output){
PythonObject inputs = pyModel.attr("outputs");
PythonObject tensor = inputs.get(output);
PythonObject tensorShape = tensor.attr("shape");
PythonObject shapeList = Python.list(tensorShape);
PythonObject pyNdim = Python.len(shapeList);
int ndim = pyNdim.toInt();
long[] shape = new long[ndim];
for(int i = 0; i < shape.length; i++){
PythonObject pyDim = shapeList.get(i);
if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
shape[i] = -1;
}
else{
shape[i] = pyDim.toLong();
}
}
pyNdim.del();
shapeList.del();
tensorShape.del();
tensor.del();
inputs.del();
return shape;
}
}

View File

@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
@ -153,11 +154,22 @@ public class TestUtils {
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
}
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) {
return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng);
}
public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){
boolean ncw = format == RNNFormat.NCW;
long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize};
char order = ncw ? 'f' : 'c';
INDArray out = Nd4j.create(DataType.FLOAT, shape, order);
for( int i=0; i<minibatch; i++ ){
for( int j=0; j<tsLength; j++ ){
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
if(ncw){
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
} else {
out.putScalar(i, j, rng.nextInt(outSize), 1.0);
}
}
}
return out;

View File

@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.*;
@ -560,75 +561,81 @@ public class GradientCheckTests extends BaseDL4JTest {
public void testEmbeddingSequenceLayer(){
Nd4j.getRandom().setSeed(12345);
for(boolean maskArray : new boolean[]{false, true}){
for(int inputRank : new int[]{2,3}) {
for(RNNFormat seqOutputFormat : RNNFormat.values()) {
for (boolean maskArray : new boolean[]{false, true}) {
for (int inputRank : new int[]{2, 3}) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.seed(12345)
.updater(new NoOp())
.weightInit(new NormalDistribution(0, 1))
.list()
.layer(new EmbeddingSequenceLayer.Builder()
.nIn(8)
.nOut(4)
.build())
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
.lossFunction(LossFunction.MSE).build())
.build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.dataType(DataType.DOUBLE)
.seed(12345)
.updater(new NoOp())
.weightInit(new NormalDistribution(0, 1))
.list()
.layer(new EmbeddingSequenceLayer.Builder()
.nIn(8)
.nOut(4)
.outputDataFormat(seqOutputFormat)
.build())
.layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
.dataFormat(seqOutputFormat)
.lossFunction(LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
INDArray label = Nd4j.rand(new int[]{3, 3, 6});
boolean ncw = seqOutputFormat == RNNFormat.NCW;
if(inputRank == 3){
//Reshape from [3,6] to [3,1,6]
in = in.reshape('c', 3, 1, 6);
}
INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
INDArray label = Nd4j.rand(DataType.FLOAT, ncw ? new int[]{3, 3, 6} : new int[]{3,6,3});
INDArray fMask = null;
if (maskArray) {
fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
{1, 1, 0, 0, 0, 0},
{1, 0, 0, 0, 0, 0}});
}
String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(label).inputMask(fMask));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
//Also: if mask is present, double check that the masked steps don't impact score
if (maskArray) {
DataSet ds = new DataSet(in, label, fMask, null);
double score = net.score(ds);
if(inputRank == 2){
in.putScalar(1, 2, 0);
in.putScalar(2, 1, 0);
in.putScalar(2, 2, 0);
} else {
in.putScalar(1, 0, 2, 0);
in.putScalar(2, 0, 1, 0);
in.putScalar(2, 0, 2, 0);
if (inputRank == 3) {
//Reshape from [3,6] to [3,1,6]
in = in.reshape('c', 3, 1, 6);
}
double score2 = net.score(ds);
assertEquals(score, score2, 1e-6);
if(inputRank == 2){
in.putScalar(1, 2, 1);
in.putScalar(2, 1, 1);
in.putScalar(2, 2, 1);
} else {
in.putScalar(1, 0, 2, 1);
in.putScalar(2, 0, 1, 1);
in.putScalar(2, 0, 2, 1);
INDArray fMask = null;
if (maskArray) {
fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
{1, 1, 0, 0, 0, 0},
{1, 0, 0, 0, 0, 0}});
}
String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
.labels(label).inputMask(fMask));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
//Also: if mask is present, double check that the masked steps don't impact score
if (maskArray) {
DataSet ds = new DataSet(in, label, fMask, null);
double score = net.score(ds);
if (inputRank == 2) {
in.putScalar(1, 2, 0);
in.putScalar(2, 1, 0);
in.putScalar(2, 2, 0);
} else {
in.putScalar(1, 0, 2, 0);
in.putScalar(2, 0, 1, 0);
in.putScalar(2, 0, 2, 0);
}
double score2 = net.score(ds);
assertEquals(score, score2, 1e-6);
if (inputRank == 2) {
in.putScalar(1, 2, 1);
in.putScalar(2, 1, 1);
in.putScalar(2, 2, 1);
} else {
in.putScalar(1, 0, 2, 1);
in.putScalar(2, 0, 1, 1);
in.putScalar(2, 0, 2, 1);
}
double score3 = net.score(ds);
assertEquals(score, score3, 1e-6);
}
double score3 = net.score(ds);
assertEquals(score, score3, 1e-6);
}
}
}

View File

@ -21,9 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
@ -341,104 +339,112 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
@Test
public void testCnnDepthMerge() {
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new NormalDistribution(0, 0.1))
.updater(new NoOp()).graphBuilder().addInputs("input")
.addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addVertex("merge", new MergeVertex(), "l1", "l2")
.addLayer("outputLayer",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
.build(),
"merge")
.setOutputs("outputLayer")
.inputPreProcessor("outputLayer", new CnnToFeedForwardPreProcessor(5, 5, 4))
.build();
for(CNN2DFormat format : CNN2DFormat.values()) {
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
String msg = "testCnnDepthMerge - " + format;
Random r = new Random(12345);
INDArray input = Nd4j.rand(new int[] {5, 2, 6, 6}); //Order: examples, channels, height, width
INDArray labels = Nd4j.zeros(5, 3);
for (int i = 0; i < 5; i++)
labels.putScalar(new int[] {i, r.nextInt(3)}, 1.0);
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new NormalDistribution(0, 0.1))
.updater(new NoOp()).graphBuilder().addInputs("input")
.addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addVertex("merge", new MergeVertex(), "l1", "l2")
.addLayer("outputLayer",
new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
.build(),
"merge")
.setOutputs("outputLayer")
.setInputTypes(InputType.convolutional(6, 6, 2, format))
.build();
if (PRINT_RESULTS) {
System.out.println("testCnnDepthMerge()");
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, format == CNN2DFormat.NCHW ? new long[]{5,2,6,6} : new long[]{5,6,6,2});
INDArray labels = Nd4j.zeros(5, 3);
for (int i = 0; i < 5; i++)
labels.putScalar(new int[]{i, r.nextInt(3)}, 1.0);
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testCnnDepthMerge()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
@Test
public void testRNNWithMerging() {
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new UniformDistribution(0.2, 0.6))
.updater(new NoOp()).graphBuilder().addInputs("input")
.setOutputs("out")
.addLayer("lstm1",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"input")
.addLayer("lstm2",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"lstm1")
.addLayer("dense1",
new DenseLayer.Builder().nIn(3).nOut(3)
.activation(Activation.SIGMOID).build(),
"lstm1")
.addLayer("lstm3",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"dense1")
.addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
.addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
"merge")
.inputPreProcessor("dense1", new RnnToFeedForwardPreProcessor())
.inputPreProcessor("lstm3", new FeedForwardToRnnPreProcessor())
.build();
for(RNNFormat format : RNNFormat.values()) {
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
String msg = "testLSTMWithMerging - " + format;
Random r = new Random(12345);
INDArray input = Nd4j.rand(new int[] {2, 3, 4});
INDArray labels = TestUtils.randomOneHotTimeSeries(2, 3, 4);
Nd4j.getRandom().setSeed(12345);
ComputationGraphConfiguration conf =
new NeuralNetConfiguration.Builder().seed(12345)
.dataType(DataType.DOUBLE)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.dist(new UniformDistribution(0.2, 0.6))
.updater(new NoOp()).graphBuilder().addInputs("input")
.setOutputs("out")
.addLayer("lstm1",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"input")
.addLayer("lstm2",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"lstm1")
.addLayer("dense1",
new DenseLayer.Builder().nIn(3).nOut(3)
.activation(Activation.SIGMOID).build(),
"lstm1")
.addLayer("lstm3",
new SimpleRnn.Builder().nIn(3).nOut(3)
.activation(Activation.TANH).build(),
"dense1")
.addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
.addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build(),
"merge")
.setInputTypes(InputType.recurrent(4, format))
.build();
if (PRINT_RESULTS) {
System.out.println("testLSTMWithMerging()");
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
Random r = new Random(12345);
INDArray input = Nd4j.rand(DataType.DOUBLE, format == RNNFormat.NCW ? new long[]{2, 3, 4} : new long[]{2,4,3});
INDArray labels = TestUtils.randomOneHotTimeSeries(format, 2, 3, 4, new Random(12345));
if (PRINT_RESULTS) {
System.out.println(msg);
// for (int j = 0; j < graph.getNumLayers(); j++)
// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
.labels(new INDArray[]{labels}));
String msg = "testLSTMWithMerging()";
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(graph);
}
@Test

View File

@ -17,7 +17,9 @@
package org.deeplearning4j.nn.dtypes;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.preprocessor.*;
import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.reflect.ClassPath;
import lombok.extern.slf4j.Slf4j;
@ -51,16 +53,11 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.util.IdentityLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
@ -97,7 +94,8 @@ public class DTypeTests extends BaseDL4JTest {
Pooling2D.class, //Alias for SubsamplingLayer
Convolution2D.class, //Alias for ConvolutionLayer
Pooling1D.class, //Alias for Subsampling1D
Convolution1D.class //Alias for Convolution1DLayer
Convolution1D.class, //Alias for Convolution1DLayer
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
));
@Override
@ -1078,7 +1076,7 @@ public class DTypeTests extends BaseDL4JTest {
.addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in")
.addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l")
.addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc")
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16})), "preproc2")
.addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16}, false)), "preproc2")
.addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3")
.setInputTypes(InputType.feedForward(5))
.setOutputs("out");
@ -1150,7 +1148,7 @@ public class DTypeTests extends BaseDL4JTest {
case 7:
b.addInputs("in")
.addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
.addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5)), "1")
.addVertex("2", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 5)), "1")
.addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
.setOutputs("out")
.setInputTypes(InputType.convolutional(28, 28, 1));

View File

@ -60,7 +60,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test
public void testMergeNode() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4);
INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100);
@ -82,7 +82,7 @@ public class TestGraphNodes extends BaseDL4JTest {
public void testMergeNodeRNN() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5);
INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100);
@ -103,7 +103,7 @@ public class TestGraphNodes extends BaseDL4JTest {
@Test
public void testCnnDepthMerge() {
Nd4j.getRandom().setSeed(12345);
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);
INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2);
INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10);

View File

@ -15,14 +15,13 @@
******************************************************************************/
package org.deeplearning4j.nn.layers.convolution;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.*;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
@ -30,8 +29,12 @@ import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@ -516,6 +519,40 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
}
@Test
public void testGlobalPooling() {
try {
for (boolean helpers : new boolean[]{false, true}) {
for (PoolingType pt : PoolingType.values()) {
Nd4j.getRandom().setSeed(12345);
Nd4j.getEnvironment().allowHelpers(helpers);
String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
System.out.println(" --- " + msg + " ---");
INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
INDArray labels = TestUtils.randomOneHot(2, 10);
TestCase tc = TestCase.builder()
.msg(msg)
.net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
.net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
.net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
.net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
.inNCHW(inNCHW)
.labelsNCHW(labels)
.labelsNHWC(labels)
.testLayerIdx(1)
.build();
testHelper(tc);
}
}
} finally {
Nd4j.getEnvironment().allowHelpers(true);
}
}
private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
if (setOnLayerAlso) {
return getNetWithLayer(new ConvolutionLayer.Builder()
@ -735,11 +772,28 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
.setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));
if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
//Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
//DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
}
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
}
private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
if (setOnLayerAlso) {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
.build(), format, ConvolutionMode.Same, null);
} else {
return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
.build(), format, ConvolutionMode.Same, null);
}
}
private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){
NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
.seed(12345)
@ -799,8 +853,13 @@ public class ConvDataFormatTests extends BaseDL4JTest {
INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1);
assertEquals(tc.msg, l0_1, l0_2);
assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
if(l0_1.rank() == 4) {
assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
} else {
assertEquals(tc.msg, l0_1, l0_3);
assertEquals(tc.msg, l0_1, l0_4);
}
INDArray out1 = tc.net1.output(inNCHW);
@ -880,4 +939,36 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
return differs;
}
//Converts NHWC to NCHW activations
@EqualsAndHashCode
private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
}
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
}
@Override
public InputPreProcessor clone() {
return this;
}
@Override
public InputType getOutputType(InputType inputType) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
}
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
return null;
}
}
}

View File

@ -44,6 +44,7 @@ import org.nd4j.linalg.primitives.Pair;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
@ -217,7 +218,7 @@ public class TestRnnLayers extends BaseDL4JTest {
NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()
.list()
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
.layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build());
switch (i){
case 0:
@ -235,10 +236,7 @@ public class TestRnnLayers extends BaseDL4JTest {
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
if (rnnDataFormat == RNNFormat.NWC){
l = l.permute(0, 2, 1);
}
INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new Random(12345));
try{
net.fit(in,l);
} catch (Throwable t){

View File

@ -61,15 +61,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
int tsLength = 7;
INDArray in;
if (rnnDataFormat == RNNFormat.NCW){
in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
}
else{
in = Nd4j.rand(DataType.FLOAT, new int[]{m, tsLength, nIn});
in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
}
// in.get(all(), all(), interval(1,tsLength)).assign(0);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new NoOp())
.weightInit(WeightInit.XAVIER)

View File

@ -7,10 +7,14 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -106,4 +110,73 @@ public class TestTimeDistributed extends BaseDL4JTest {
}
}
}
@Test
public void testTimeDistributedDense(){
for( int rnnType=0; rnnType<3; rnnType++ ) {
for( int ffType=0; ffType<3; ffType++ ) {
Layer l0, l2;
switch (rnnType) {
case 0:
l0 = new LSTM.Builder().nOut(5).build();
l2 = new LSTM.Builder().nOut(5).build();
break;
case 1:
l0 = new SimpleRnn.Builder().nOut(5).build();
l2 = new SimpleRnn.Builder().nOut(5).build();
break;
case 2:
l0 = new Bidirectional(new LSTM.Builder().nOut(5).build());
l2 = new Bidirectional(new LSTM.Builder().nOut(5).build());
break;
default:
throw new RuntimeException("Not implemented: " + rnnType);
}
Layer l1;
switch (ffType){
case 0:
l1 = new DenseLayer.Builder().nOut(5).build();
break;
case 1:
l1 = new VariationalAutoencoder.Builder().nOut(5).encoderLayerSizes(5).decoderLayerSizes(5).build();
break;
case 2:
l1 = new AutoEncoder.Builder().nOut(5).build();
break;
default:
throw new RuntimeException("Not implemented: " + ffType);
}
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.list()
.layer(l0)
.layer(l1)
.layer(l2)
.setInputType(InputType.recurrent(5, 9, rnnDataFormat))
.build();
BaseRecurrentLayer l0a;
BaseRecurrentLayer l2a;
if (rnnType < 2) {
l0a = (BaseRecurrentLayer) l0;
l2a = (BaseRecurrentLayer) l2;
} else {
l0a = (BaseRecurrentLayer) ((Bidirectional) l0).getFwd();
l2a = (BaseRecurrentLayer) ((Bidirectional) l2).getFwd();
}
assertEquals(rnnDataFormat, l0a.getRnnDataFormat());
assertEquals(rnnDataFormat, l2a.getRnnDataFormat());
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
INDArray in = Nd4j.rand(DataType.FLOAT, rnnDataFormat == RNNFormat.NCW ? new long[]{2, 5, 9} : new long[]{2, 9, 5} );
net.output(in);
}
}
}
}

View File

@ -15,21 +15,24 @@
******************************************************************************/
package org.deeplearning4j.convolution;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.*;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.CuDNNTestUtils;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@ -816,6 +819,12 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
.setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));
if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
//Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
//DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
}
MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
net.init();
return net;
@ -964,4 +973,35 @@ public class ConvDataFormatTests extends BaseDL4JTest {
}
return differs;
}
//Converts NHWC to NCHW activations
@EqualsAndHashCode
private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
}
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
}
@Override
public InputPreProcessor clone() {
return this;
}
@Override
public InputType getOutputType(InputType inputType) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
}
@Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
return null;
}
}
}

View File

@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest {
ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
model = model.convertDataType(DataType.DOUBLE);
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize});
INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC
CuDNNTestUtils.assertHelpersPresent(model.getLayers());
Map<String,INDArray> withCudnn = model.feedForward(in, false);

View File

@ -113,6 +113,12 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-python</artifactId>
<version>${datavec.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>

View File

@ -19,6 +19,9 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -121,27 +124,29 @@ public class KerasInput extends KerasLayer {
InputType myInputType;
switch (this.inputShape.length) {
case 1:
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
break;
case 2:
if(this.dimOrder != null) {
System.out.println("Dim order: " + this.dimOrder);
System.out.println("Input shape: " + ArrayUtils.toString(this.inputShape));
switch (this.dimOrder) {
case TENSORFLOW: //NWC == channels_last
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break;
case THEANO: //NCW == channels_first
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
break;
case NONE:
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
break;
default:
throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
}
} else {
//Assume RNN in [mb, seqLen, size] format
myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
}
break;
@ -150,17 +155,17 @@ public class KerasInput extends KerasLayer {
case TENSORFLOW:
/* TensorFlow convolutional input: # rows, # cols, # channels */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
this.inputShape[2]);
this.inputShape[2], CNN2DFormat.NHWC);
break;
case THEANO:
/* Theano convolutional input: # channels, # rows, # cols */
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]);
this.inputShape[0], CNN2DFormat.NCHW);
break;
default:
this.dimOrder = DimOrder.THEANO;
myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
this.inputShape[0]);
this.inputShape[0], CNN2DFormat.NCHW);
log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
"versions may come without specified backend, in which case we assume the model was " +
"built with theano." );

View File

@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
@ -65,6 +66,9 @@ public class TFOpLayer extends Layer {
long[] shape = inputType.getShape(true);
TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
long[] outputShape = tempLayer.getOutputShape(shape);
if (outputShape.length == 3){
return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC);
}
return InputType.inferInputType(Nd4j.create(outputShape));
}

View File

@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
}
private INDArray runGraph(INDArray input){
if (input.rank() == 3){
// TODO make this a preprocessor
input = input.permute(0, 2, 1);
}
Map<String, INDArray> inputMap = new HashMap<>();
inputMap.put(inputNames.get(0), input);
INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
if (out.rank() == 3){
out = out.permute(0, 2, 1); // TODO post-processing?
}
return out;
}

View File

@ -95,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution {
IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
enforceTrainingConfig, conf, kerasMajorVersion);
Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
@ -104,7 +103,7 @@ public class KerasConvolution1D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
.hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
.stride(getStrideFromConfig(layerConfig, 1, conf)[0]).rnnDataFormat(dimOrder == DimOrder.TENSORFLOW? RNNFormat.NWC: RNNFormat.NCW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
if (hasBias)
builder.biasInit(0.0);

View File

@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
@ -27,6 +28,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.IWeightInit;
import oshi.jna.platform.windows.PowrProf;
import java.util.Map;
@ -93,6 +95,7 @@ public class KerasConvolution2D extends KerasConvolution {
LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);
System.out.println("----" + dimOrder);
ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getIActivationFromConfig(layerConfig, conf))
@ -101,7 +104,8 @@ public class KerasConvolution2D extends KerasConvolution {
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
.hasBias(hasBias)
.stride(getStrideFromConfig(layerConfig, 2, conf));
.stride(getStrideFromConfig(layerConfig, 2, conf))
.dataFormat((dimOrder==DimOrder.TENSORFLOW)? CNN2DFormat.NHWC:CNN2DFormat.NCHW);
int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion);
if (hasBias)
builder.biasInit(0.0);

View File

@ -360,8 +360,19 @@ public class KerasConvolutionUtils {
}
} else if (dimension == 1) {
int paddingInt = (int) innerConfig.get(layerField);
padding = new int[]{paddingInt, paddingInt};
Object paddingObj = innerConfig.get(layerField);
if (paddingObj instanceof List){
List<Integer> paddingList = (List)paddingObj;
padding = new int[]{
paddingList.get(0),
paddingList.get(1)
};
}
else{
int paddingInt = (int) innerConfig.get(layerField);
padding = new int[]{paddingInt, paddingInt};
}
} else {
throw new UnsupportedKerasConfigurationException(
"Keras padding layer not supported");

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import java.util.Map;
@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer {
switch (this.getDimOrder()) {
case NONE:
case THEANO:
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels());
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW);
break;
case TENSORFLOW:
preprocessor = new TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(),
it.getChannels());
preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC);
break;
default:
throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder());
@ -111,7 +110,7 @@ public class KerasFlatten extends KerasLayer {
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null);
}
return preprocessor;
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer {
super(layerConfig, enforceTrainingConfig);
this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf))
.dataFormat(RNNFormat.NWC)
.name(this.layerName).build();
}

View File

@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW);
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
if (inputShape[0] != targetShape[0])
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC);
}
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
}
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
} else {
if (inputShape[0] != targetShape[0])
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
}
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
if (targetShape.length == 3) {
targetShape = targetShapeForDimOrder(inputShape, targetShape);
}
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
}
return preprocessor;
}

View File

@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer {
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization)
.outputDataFormat(RNNFormat.NWC)
.hasBias(false);
if (embeddingConstraint != null)
builder.constrainWeights(embeddingConstraint);

View File

@ -186,7 +186,7 @@ public class KerasLSTM extends KerasLayer {
.weightInitRecurrent(recurrentInit)
.biasInit(0.0) // TODO: this is incorrect
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);

View File

@ -158,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer {
.weightInitRecurrent(recurrentInit)
.biasInit(0.0)
.l1(this.weightL1Regularization)
.l2(this.weightL2Regularization);
.l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
if(nIn != null)
builder.setNIn(nIn);

View File

@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer {
break;
case "SimpleRNN":
kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
this.layer = new Bidirectional(mode, rnnLayer);
layer.setLayerName(layerName);
break;

View File

@ -21,6 +21,9 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
private final long[] inputShape;
private final long[] targetShape;
private boolean hasMiniBatchDimension;
/**
* @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
*/
@Deprecated
public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
this(inputShape, targetShape, false);
}
private DataFormat format;
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
*/
public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) {
this(inputShape, targetShape, hasMiniBatchDimension, null);
}
/**
* @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
* @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
* @param dataFormat May be null. If non-null:
*/
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
@JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension,
@JsonProperty("dataFormat") DataFormat dataFormat) {
this.inputShape = inputShape;
this.targetShape = targetShape;
this.hasMiniBatchDimension = hasMiniBatchDimension;
this.format = dataFormat;
}
private long[] getShape(long[] originalShape, long minibatch) {
@ -140,13 +148,26 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
ret = InputType.feedForward(shape[1]);
break;
case 3:
ret = InputType.recurrent(shape[2], shape[1]);
RNNFormat format = RNNFormat.NCW;
if(this.format != null && this.format instanceof RNNFormat)
format = (RNNFormat)this.format;
ret = InputType.recurrent(shape[2], shape[1], format);
break;
case 4:
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
} else {
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
CNN2DFormat cnnFormat = CNN2DFormat.NCHW;
if (this.format != null && this.format instanceof CNN2DFormat)
cnnFormat = (CNN2DFormat) this.format;
if (cnnFormat == CNN2DFormat.NCHW) {
ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
} else {
ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
}
}
break;
default:

View File

@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
/**
* Specialized CnnToFeedForwardInputPreProcessor for use with
* Convolutional layers imported from Keras using the TensorFlow
* backend.
*
* @author dave@skymind.io
* @deprecated Exists only for backward compatibility of older pretrained models. Should not be used.
* Use {@link CnnToFeedForwardPreProcessor} for all new models instead.
*/
@Slf4j
@Slf4j @Deprecated
public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor {
@JsonCreator
@JsonCreator @Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight,
@JsonProperty("inputWidth") long inputWidth,
@JsonProperty("numChannels") long numChannels) {
super(inputHeight, inputWidth, numChannels);
}
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
super(inputHeight, inputWidth);
}
@Deprecated
public TensorFlowCnnToFeedForwardPreProcessor() {
super();
}

View File

@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;
@ -319,6 +320,8 @@ public class KerasLayerUtils {
layer = new KerasELU(layerConfig, enforceTrainingConfig);
} else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
} else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){
layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
} else if (conf instanceof Keras2LayerConfiguration){
Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf;
if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){

View File

@ -1,50 +0,0 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.Arrays;
public class TFKerasTests extends BaseDL4JTest{
@Test
public void testModelWithTFOp1() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
}
@Test
public void testModelWithTFOp2() throws Exception{
File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
// dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed
long[] expectedShape = new long[]{12 * 2, 5};
Assert.assertArrayEquals(expectedShape, out.shape());
}
}

View File

@ -0,0 +1,147 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras;
import org.apache.commons.io.FileUtils;
import org.datavec.python.keras.Model;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.common.tests.ResourceUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.List;
@RunWith(Parameterized.class)
public class TestTFKerasModelImport extends BaseDL4JTest{
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
private String modelFile;
@Override
public long getTimeoutMilliseconds(){
return 300000;
} // installing TF will take a while
@Parameterized.Parameters(name = "file={0}")
public static Object[] params() throws Exception {
List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false);
return paths.toArray(new String[0]);
}
public TestTFKerasModelImport(String modelFile){
this.modelFile = modelFile;
}
@Test
public void testModelImport() throws Exception{
testModelImportWithData(modelFile);
}
private void testModelImportWithData(String path) throws Exception{
System.out.println(path);
// TODO multi input/output
INDArray inputArray;
INDArray expectedOutputArray;
File f = Resources.asFile(path); //May in in JAR that HDF5 can't read from
File modelFile = new File(testDir.getRoot(), f.getName());
FileUtils.copyFile(f, modelFile);
synchronized (Hdf5Archive.LOCK_OBJECT){
Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath());
List<String> rootGroups = hdf5Archive.getGroups();
if (rootGroups.contains("data")){
String inputName = hdf5Archive.readAttributeAsString("input_names", "data");
String outputName = hdf5Archive.readAttributeAsString("output_names", "data");
inputArray = hdf5Archive.readDataSet(inputName, "data");
expectedOutputArray = hdf5Archive.readDataSet(outputName, "data");
}
else{
hdf5Archive.close();
return;
}
hdf5Archive.close();
}
INDArray outputArray;
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
outputArray = dl4jModel.outputSingle(inputArray);
expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT);
outputArray = outputArray.castTo(DataType.FLOAT);
if (path.contains("misc_")){
//shape relaxation
expectedOutputArray = expectedOutputArray.reshape( -1);
outputArray = outputArray.reshape(-1);
}
System.out.println(outputArray.toString());
System.out.println(expectedOutputArray.toString());
Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape());
Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3));
}
private void testModelImportWithKeras(String path) throws Exception{
Model kerasModel = new Model(path);
ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays());
Assert.assertEquals(kerasModel.numOutputs(), dl4jModel.getNumOutputArrays());
INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()];
INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()];
for (int i = 0; i < kerasInputArrays.length; i ++) {
long[] shape = kerasModel.inputShapeAt(i);
for (int j = 0; j < shape.length; j++) {
if (shape[j] < 0) {
shape[j] = 1;
}
}
kerasInputArrays[i] = Nd4j.rand(shape);
}
INDArray[] kerasOut = kerasModel.predict(kerasInputArrays);
INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays);
Assert.assertEquals(kerasOut.length, dl4jOut.length);
for (int i = 0; i < kerasOut.length; i++){
INDArray kerasOutArr = kerasOut[i];
kerasOutArr = kerasOutArr.reshape(1, -1);// bit of relaxation on shape
kerasOutArr= kerasOutArr.castTo(DataType.DOUBLE);
Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST);
INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1);
System.out.println(kerasOutArr.shapeInfoToString());
System.out.println(dl4jOutArr.shapeInfoToString());
Assert.assertEquals(kerasOutArr, dl4jOutArr);
}
}
}

View File

@ -22,7 +22,6 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@ -34,8 +33,7 @@ public class JsonTest extends BaseDL4JTest {
InputPreProcessor[] pp = new InputPreProcessor[] {
new KerasFlattenRnnPreprocessor(10, 5),
new PermutePreprocessor(new int[]{0,1,2}),
new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}),
new TensorFlowCnnToFeedForwardPreProcessor()
new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}, true, null)
};
for(InputPreProcessor p : pp ){

View File

@ -29,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceTo
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;
@ -250,7 +251,7 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
.enforceTrainingConfig(false).buildSequential().getMultiLayerConfiguration();
MultiLayerNetwork model = new MultiLayerNetwork(config);
model.init();
INDArray input = Nd4j.create(50, 500, 1500);
INDArray input = Nd4j.create(DataType.FLOAT, 50, 1500, 500); //NWC format - [Minibatch, seqLength, channels]
INDArray out = model.output(input);
assertTrue(Arrays.equals(out.shape(), new long[]{50, 64}));
}

View File

@ -87,15 +87,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Rule
public final TemporaryFolder testDir = new TemporaryFolder();
public static final BiFunction<String,INDArray,INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() {
@Override
public INDArray apply(String s, INDArray array) {
if(array.rank() == 3)
return array.permute(0, 2, 1); //NWC to NCW
return array;
}
};
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources
@ -169,28 +160,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importImdbLstmTfKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
}
@Test
public void importImdbLstmThKeras1() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
}
@Test
public void importImdbLstmTfKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
}
@Test
public void importImdbLstmThKeras2() throws Exception {
String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, null);
}
/**
@ -262,7 +253,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
"simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
}
/**
@ -316,7 +307,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Test
public void importAcganDiscriminator() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_discriminator_1_epochs.h5");
INDArray input = Nd4j.create(10, 1, 28, 28);
INDArray input = Nd4j.create(10, 28, 28, 1); //NHWC
INDArray[] output = model.output(input);
}
@ -403,7 +394,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
// Make predictions
int miniBatch = 32;
INDArray input = Nd4j.ones(miniBatch, 4, 10);
INDArray input = Nd4j.ones(miniBatch, 10, 4); //NWC format - with nIn=4, seqLength = 10
INDArray[] out = graph.output(input);
// Fit model
@ -450,7 +441,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
@Test
public void importMobileNet() throws Exception {
ComputationGraph graph = importFunctionalModelH5Test("modelimport/keras/examples/mobilenet/alternative.hdf5");
INDArray input = Nd4j.ones(10, 3, 299, 299);
INDArray input = Nd4j.ones(10, 299, 299, 3);
graph.output(input);
}
@ -462,7 +453,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
int[] inputShape = new int[]{299, 299, 3};
ComputationGraph graph = importFunctionalModelH5Test(
"modelimport/keras/examples/inception/inception_tf_keras_2.h5", inputShape, false);
INDArray input = Nd4j.ones(10, 3, 299, 299);
INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC
graph.output(input);
System.out.println(graph.summary());
}
@ -476,7 +467,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
public void importInception() throws Exception {
ComputationGraph graph = importFunctionalModelH5Test(
"modelimport/keras/examples/inception/inception_v3_complete.h5");
INDArray input = Nd4j.ones(10, 3, 299, 299);
INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC
graph.output(input);
System.out.println(graph.summary());
}
@ -533,14 +524,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
* - Separate (policy and value) residual architecture
* - Separate (policy and value) convolutional architecture
*/
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importSepConvPolicy() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_policy.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
model.output(input);
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importSepResPolicy() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_policy.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
@ -548,28 +539,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importSepConvValue() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_value.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
model.output(input);
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importSepResValue() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_value.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
model.output(input);
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importDualRes() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_res.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
model.output(input);
}
@Test
@Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
public void importDualConv() throws Exception {
ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_conv.h5");
INDArray input = Nd4j.create(32, 19, 19, 10);
@ -634,16 +625,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
System.out.println("Starting test: " + name);
String modelPath = "modelimport/keras/examples/causal_conv1d/" + name;
String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
Function<INDArray,INDArray> f = new Function<INDArray, INDArray>() {
@Override
public INDArray apply(INDArray i) {
//NWC to NCW
return i.permute(0, 2, 1);
}
};
MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
true, true, false, f, nwc2ncwExpected);
true, true, false, null, null);
Layer l = net.getLayer(0);
Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig();
assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
@ -707,25 +691,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
System.out.println("Starting test: " + name);
String modelPath = "modelimport/keras/examples/conv1d/" + name;
String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5");
Function<INDArray,INDArray> f = name.contains("_cf_") ? null : new Function<INDArray, INDArray>() {
@Override
public INDArray apply(INDArray i) {
//NWC to NCW
return i.permute(0, 2, 1);
}
};
BiFunction<String,INDArray,INDArray> f2 = name.contains("_cf_") ? null : new BiFunction<String, INDArray, INDArray>() {
@Override
public INDArray apply(String s, INDArray array) {
// if("conv".equals(s)){
return array.permute(0, 2, 1);
// }
}
};
importEndModelTest(modelPath, inputsOutputPath, true, true,
true, true, false, f, f2);
true, true, false, null, null); //f, f2);
}
}
@ -882,8 +850,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
INDArray[] inputs = new INDArray[inputNames.size()];
for (int i = 0; i < inputNames.size(); i++) {
inputs[i] = archive.readDataSet(inputNames.get(i), GROUP_ATTR_INPUTS);
if (inputs[i].shape().length == 4 && tensorFlowImageDimOrdering)
inputs[i] = inputs[i].permute(0, 3, 1, 2);
}
return inputs;
}
@ -893,8 +859,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
Map<String, INDArray> activations = new HashMap<String, INDArray>();
for (String layerName : archive.getDataSets(GROUP_ACTIVATIONS)) {
INDArray activation = archive.readDataSet(layerName, GROUP_ACTIVATIONS);
if (activation.shape().length == 4 && tensorFlowImageDimOrdering)
activation = activation.permute(0, 3, 1, 2);
activations.put(layerName, activation);
}
return activations;
@ -907,8 +871,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
INDArray[] outputs = new INDArray[outputNames.size()];
for (int i = 0; i < outputNames.size(); i++) {
outputs[i] = archive.readDataSet(outputNames.get(i), GROUP_ATTR_OUTPUTS);
if (outputs[i].shape().length == 4 && tensorFlowImageDimOrdering)
outputs[i] = outputs[i].permute(0, 3, 1, 2);
}
return outputs;
}
@ -920,8 +882,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
INDArray[] predictions = new INDArray[outputNames.size()];
for (int i = 0; i < outputNames.size(); i++) {
predictions[i] = archive.readDataSet(outputNames.get(i), GROUP_PREDICTIONS);
if (predictions[i].shape().length == 4 && tensorFlowImageDimOrdering)
predictions[i] = predictions[i].permute(0, 3, 1, 2);
}
return predictions;
}
@ -941,6 +901,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
// skip too small absolute inputs
if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) {
boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps);
if(!eq){
System.out.println("Expected: " + Arrays.toString(expected.shape()) + ", actual: " + Arrays.toString(actual.shape()));
System.out.println("Expected:\n" + expected);
System.out.println("Actual: \n" + actual);
}
assertTrue("Output differs: " + label, eq);
}
}

View File

@ -176,10 +176,10 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
INDArray bias = model.getLayer(0).getParam("b");
assertEquals(6, bias.length());
INDArray input = Nd4j.ones(1, 5, 3, 4);
INDArray input = Nd4j.ones(1, 3, 4, 5); //NHWC
INDArray output = model.output(input);
assertArrayEquals(new long[] {1, 6, 1, 2}, output.shape());
assertArrayEquals(new long[] {1, 1, 2, 6}, output.shape()); //NHWC
logSuccess(modelPath);
}
@ -224,7 +224,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
INDArray input = Nd4j.zeros(mb, inputLength);
INDArray output = model.output(input);
assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape());
assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC
logSuccess(modelPath);
}
@ -238,9 +238,9 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);
MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false);
INDArray input = Nd4j.zeros(10, 4, 6, 6);
INDArray input = Nd4j.zeros(10, 6, 6, 4);
INDArray output = model.output(input);
assertArrayEquals(new long[]{10, 16, 3, 3}, output.shape());
assertArrayEquals(new long[]{10, 3, 3, 16}, output.shape());
logSuccess(modelPath);
}
@ -248,10 +248,11 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);
ComputationGraph model = loadComputationalGraph(modelPath, false);
INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)};
// INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)};
INDArray input[] = new INDArray[]{Nd4j.zeros(10, 6, 6, 4), Nd4j.zeros(10, 3, 3, 16)};
INDArray[] output = model.output(input);
log.info(Arrays.toString(output[0].shape()));
assertArrayEquals(new long[]{10, 32, 3, 3}, output[0].shape());
assertArrayEquals(new long[]{10, 3, 3, 32}, output[0].shape());
logSuccess(modelPath);
}
@ -278,7 +279,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
INDArray inEmbedding = Nd4j.zeros(mb, inputLength);
INDArray output = model.output(inEmbedding);
assertArrayEquals(new long[]{mb, nOut, inputLength}, output.shape());
assertArrayEquals(new long[]{mb, inputLength, nOut}, output.shape()); //NWC format
logSuccess(modelPath);
}
@ -304,7 +305,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
INDArray inEmbedding = Nd4j.zeros(mb, inputLength);
INDArray output = model.output(inEmbedding);
assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape());
assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC
logSuccess(modelPath);
}

View File

@ -9,7 +9,7 @@ package org.deeplearning4j.nn.conf;
*
* @author Alex Black
*/
public enum CNN2DFormat {
public enum CNN2DFormat implements DataFormat {
NCHW,
NHWC;

View File

@ -0,0 +1,26 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf;
import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
@JsonSerialize(using = DataFormatSerializer.class)
@JsonDeserialize(using = DataFormatDeserializer.class)
public interface DataFormat {
}

View File

@ -663,7 +663,7 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer;
val nIn = brl.getNIn();
if (nIn > 0) {
inputType = InputType.recurrent(nIn);
inputType = InputType.recurrent(nIn, brl.getRnnDataFormat());
}
} else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer
|| firstLayer instanceof OutputLayer) {

View File

@ -23,7 +23,7 @@ package org.deeplearning4j.nn.conf;
* "width" corresponds to sequence length and "channels" corresponds to sequence item size.
*/
public enum RNNFormat {
public enum RNNFormat implements DataFormat {
NCW,
NWC
}

View File

@ -18,6 +18,8 @@ package org.deeplearning4j.nn.conf.graph;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
@ -38,6 +40,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
*/
public class MergeVertex extends GraphVertex {
protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format
@Override
public MergeVertex clone() {
return new MergeVertex();
@ -76,7 +80,7 @@ public class MergeVertex extends GraphVertex {
@Override
public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype);
return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype, mergeAxis);
}
@Override
@ -126,6 +130,7 @@ public class MergeVertex extends GraphVertex {
//FF or RNN data inputs
int size = 0;
InputType.Type type = null;
RNNFormat format = null;
for (int i = 0; i < vertexInputs.length; i++) {
if (vertexInputs[i].getType() != first.getType()) {
throw new InvalidInputTypeException(
@ -142,6 +147,8 @@ public class MergeVertex extends GraphVertex {
break;
case RNN:
thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize();
format = ((InputType.InputTypeRecurrent) vertexInputs[i]).getFormat();
this.mergeAxis = format == RNNFormat.NCW ? 1 : 2;
type = InputType.Type.RNN;
break;
default:
@ -160,7 +167,7 @@ public class MergeVertex extends GraphVertex {
return InputType.feedForward(size);
} else {
val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength();
return InputType.recurrent(size, tsLength);
return InputType.recurrent(size, tsLength, format);
}
} else {
//size is unknown
@ -168,13 +175,14 @@ public class MergeVertex extends GraphVertex {
return InputType.feedForward(-1);
} else {
val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength();
return InputType.recurrent(-1, tsLength);
return InputType.recurrent(-1, tsLength, format);
}
}
} else {
//CNN inputs... also check that the channels, width and heights match:
InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
CNN2DFormat format = firstConv.getFormat();
val fd = firstConv.getChannels();
val fw = firstConv.getWidth();
@ -206,7 +214,8 @@ public class MergeVertex extends GraphVertex {
depthSum += od;
}
return InputType.convolutional(fh, fw, depthSum);
this.mergeAxis = format == CNN2DFormat.NCHW ? 1 : 3;
return InputType.convolutional(fh, fw, depthSum, format);
}
}

View File

@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
@ -91,7 +92,11 @@ public abstract class InputType implements Serializable {
* @return InputTypeFeedForward
*/
public static InputType feedForward(long size) {
return new InputTypeFeedForward(size);
return new InputTypeFeedForward(size, null);
}
public static InputType feedForward(long size, DataFormat timeDistributedFormat) {
return new InputTypeFeedForward(size,timeDistributedFormat);
}
/**
@ -132,7 +137,6 @@ public abstract class InputType implements Serializable {
* @return InputTypeConvolutional
*/
public static InputType convolutional(long height, long width, long depth) {
// return new InputTypeConvolutional(height, width, depth);
return convolutional(height, width, depth, CNN2DFormat.NCHW);
}
@ -191,9 +195,11 @@ public abstract class InputType implements Serializable {
@EqualsAndHashCode(callSuper = false)
public static class InputTypeFeedForward extends InputType {
private long size;
private DataFormat timeDistributedFormat;
public InputTypeFeedForward(@JsonProperty("size") long size) {
public InputTypeFeedForward(@JsonProperty("size") long size, @JsonProperty("timeDistributedFormat") DataFormat timeDistributedFormat) {
this.size = size;
this.timeDistributedFormat = timeDistributedFormat;
}
@Override
@ -203,7 +209,7 @@ public abstract class InputType implements Serializable {
@Override
public String toString() {
return "InputTypeFeedForward(" + size + ")";
return "InputTypeFeedForward(" + size + (timeDistributedFormat != null ? "," + timeDistributedFormat : "") + ")";
}
@Override
@ -302,7 +308,8 @@ public abstract class InputType implements Serializable {
this.height = height;
this.width = width;
this.channels = channels;
this.format = format;
if(format != null)
this.format = format;
}
public InputTypeConvolutional(long height, long width, long channels) {

View File

@ -64,11 +64,11 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
+ "\"): expect RNN input type with size > 0. Got: " + inputType);
}
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.nIn = r.getSize();
this.rnnDataFormat = r.getFormat();
}
this.rnnDataFormat = r.getFormat();
}
@Override

View File

@ -44,6 +44,7 @@ import java.util.Map;
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class Convolution1DLayer extends ConvolutionLayer {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
/*
//TODO: We will eventually want to NOT subclass off of ConvolutionLayer.
//Currently, we just subclass off the ConvolutionLayer and hard code the "width" dimension to 1
@ -56,6 +57,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
private Convolution1DLayer(Builder builder) {
super(builder);
initializeConstraints(builder);
this.rnnDataFormat = builder.rnnDataFormat;
}
@Override
@ -92,7 +94,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
convolutionMode, dilation[0]);
}
return InputType.recurrent(nOut, outLength);
return InputType.recurrent(nOut, outLength, rnnDataFormat);
}
@Override
@ -102,10 +105,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ "\"): expect RNN input type with size > 0. Got: " + inputType);
}
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.nIn = r.getSize();
}
this.rnnDataFormat = r.getFormat();
}
@Override
@ -115,11 +119,13 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ "\"): input is null");
}
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName());
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat,getLayerName());
}
public static class Builder extends ConvolutionLayer.BaseConvBuilder<Builder> {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public Builder() {
this(0, 1, 0);
this.setKernelSize((int[]) null);
@ -130,6 +136,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
return true;
}
public Builder rnnDataFormat(RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
return this;
}
/**
* @param kernelSize Kernel size
* @param stride Stride

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
@ -58,12 +59,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
private int inputLength = 1; // By default only use one index to embed
private boolean hasBias = false;
private boolean inferInputLength = false; // use input length as provided by input data
private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models
private EmbeddingSequenceLayer(Builder builder) {
super(builder);
this.hasBias = builder.hasBias;
this.inputLength = builder.inputLength;
this.inferInputLength = builder.inferInputLength;
this.outputFormat = builder.outputFormat;
initializeConstraints(builder);
}
@ -87,7 +90,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
throw new IllegalStateException("Invalid input for Embedding layer (layer index = " + layerIndex
+ ", layer name = \"" + getLayerName() + "\"): expect FF/RNN input type. Got: " + inputType);
}
return InputType.recurrent(nOut, inputLength);
return InputType.recurrent(nOut, inputLength, outputFormat);
}
@Override
@ -167,6 +170,13 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
*/
private boolean inferInputLength = true;
private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models
public Builder outputDataFormat(RNNFormat format){
this.outputFormat = format;
return this;
}
/**
* If true: include bias parameters in the layer. False (default): no bias.
*

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.conf.layers;
import lombok.*;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
@ -35,6 +36,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
protected long nIn;
protected long nOut;
protected DataFormat timeDistributedFormat;
public FeedForwardLayer(Builder builder) {
super(builder);
@ -51,7 +53,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
+ getLayerName() + "\"): expected FeedForward input type. Got: " + inputType);
}
return InputType.feedForward(nOut);
return InputType.feedForward(nOut, timeDistributedFormat);
}
@Override
@ -71,6 +73,11 @@ public abstract class FeedForwardLayer extends BaseLayer {
this.nIn = f.getFlattenedSize();
}
}
if(inputType instanceof InputType.InputTypeFeedForward){
InputType.InputTypeFeedForward f = (InputType.InputTypeFeedForward) inputType;
this.timeDistributedFormat = f.getTimeDistributedFormat();
}
}
@Override

View File

@ -536,11 +536,17 @@ public class InputTypeUtil {
}
switch (inputType.getType()) {
case FF:
case CNNFlat:
//FF -> RNN or CNNFlat -> RNN
//In either case, input data format is a row vector per example
return new FeedForwardToRnnPreProcessor(rnnDataFormat);
case FF:
//If time distributed format is defined, use that. Otherwise use the layer-defined rnnDataFormat, which may be default
InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward)inputType;
if(ff.getTimeDistributedFormat() != null && ff.getTimeDistributedFormat() instanceof RNNFormat){
return new FeedForwardToRnnPreProcessor((RNNFormat) ff.getTimeDistributedFormat());
}
return new FeedForwardToRnnPreProcessor(rnnDataFormat);
case RNN:
//RNN -> RNN: No preprocessor necessary
return null;

View File

@ -98,9 +98,9 @@ public class RnnOutputLayer extends BaseOutputLayer {
+ "\"): Expected RNN input, got " + inputType);
}
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.rnnDataFormat = r.getFormat();
if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.rnnDataFormat = r.getFormat();
this.nIn = r.getSize();
}
}

View File

@ -91,7 +91,7 @@ public class Subsampling1DLayer extends SubsamplingLayer {
outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
convolutionMode, dilation[0]);
}
return InputType.recurrent(r.getSize(), outLength);
return InputType.recurrent(r.getSize(), outLength, r.getFormat());
}
@Override

View File

@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers.misc;
import lombok.*;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
@ -46,10 +47,12 @@ import java.util.Map;
public class RepeatVector extends FeedForwardLayer {
private int n = 1;
private RNNFormat dataFormat = RNNFormat.NCW;
protected RepeatVector(Builder builder) {
super(builder);
this.n = builder.n;
this.dataFormat = builder.dataFormat;
}
@Override
@ -83,7 +86,7 @@ public class RepeatVector extends FeedForwardLayer {
+ "\"): Expected FF input, got " + inputType);
}
InputType.InputTypeFeedForward ffInput = (InputType.InputTypeFeedForward) inputType;
return InputType.recurrent(ffInput.getSize(), n);
return InputType.recurrent(ffInput.getSize(), n, this.dataFormat);
}
@Override
@ -101,13 +104,14 @@ public class RepeatVector extends FeedForwardLayer {
}
@NoArgsConstructor
@Getter
@Setter
public static class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {
private int n = 1; // no repetition by default
private RNNFormat dataFormat = RNNFormat.NCW;
/**
* Set repetition factor for RepeatVector layer
*/
@ -115,6 +119,15 @@ public class RepeatVector extends FeedForwardLayer {
return n;
}
public RNNFormat getDataFormat(){
return dataFormat;
}
public Builder dataFormat(RNNFormat dataFormat){
this.dataFormat = dataFormat;
return this;
}
/**
* Set repetition factor for RepeatVector layer
*

View File

@ -39,11 +39,13 @@ import java.util.Arrays;
* For example, CNN -> Denselayer <br>
* This does two things:<br>
* (b) Reshapes 4d activations out of CNN layer, with shape
* [numExamples, numChannels, inputHeight, inputWidth]) into 2d activations (with shape
* [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer
* [numExamples, numChannels, inputHeight, inputWidth]) (for {@link CNN2DFormat#NCHW} format activations) or shape
* [numExamples, inputHeight, inputWidth, numChannels] (for {@link CNN2DFormat#NHWC}) format activations) into 2d activations
* (with shape [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer.
* (a) Reshapes epsilons (weights*deltas) out of FeedFoward layer (which is 2D or 3D with shape
* [numExamples, inputHeight*inputWidth*numChannels]) into 4d epsilons (with shape
* [numExamples, numChannels, inputHeight, inputWidth]) suitable to feed into CNN layers.<br>
* [numExamples, numChannels, inputHeight, inputWidth] or [numExamples, inputHeight, inputWidth, numChannels]) suitable to
* feed into CNN layers.<br>
* Note: numChannels is equivalent to channels or featureMaps referenced in different literature
* @author Adam Gibson
* @see FeedForwardToCnnPreProcessor for opposite case (i.e., DenseLayer -> CNNetc)
@ -68,7 +70,8 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
this.inputHeight = inputHeight;
this.inputWidth = inputWidth;
this.numChannels = numChannels;
this.format = format;
if(format != null)
this.format = format;
}
public CnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
@ -96,10 +99,17 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
wDim = 2;
}
if(inputHeight == 0 && inputWidth == 0 && numChannels == 0){
this.inputHeight = input.size(hDim);
this.inputWidth = input.size(wDim);
this.numChannels = input.size(chDim);
}
if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || input.size(wDim) != inputWidth){
throw new IllegalStateException("Invalid input, does not match configuration: expected [minibatch, numChannels="
+ numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] but got input array of" +
"shape " + Arrays.toString(input.shape()));
throw new IllegalStateException("Invalid input, does not match configuration: expected " +
(format == CNN2DFormat.NCHW ? "[minibatch, numChannels=" + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] " :
"[minibatch, inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + ", numChannels=" + numChannels + "]") +
" but got input array of shape " + Arrays.toString(input.shape()));
}
//Check input: nchw format
@ -110,15 +120,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
+ Arrays.toString(input.shape()));
}
if(format == CNN2DFormat.NHWC) {
input = input.permute(0, 3, 1, 2); //NHWC to NCHW
}
//Assume input is standard rank 4 activations out of CNN layer
//First: we require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
//Note that to match Tensorflow/Keras, we do a simple "c order reshape" for both NCHW and NHWC
val inShape = input.shape(); //[miniBatch,depthOut,outH,outW]
val outShape = new long[]{inShape[0], inShape[1] * inShape[2] * inShape[3]};
@ -139,11 +147,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
+ inputHeight + " x columns " + inputWidth + " x channels " + numChannels + " but was instead "
+ Arrays.toString(epsilons.shape()));
INDArray ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
if(format == CNN2DFormat.NHWC){
ret = ret.permute(0,2,3,1); //NCHW to NHWC
INDArray ret;
if(format == CNN2DFormat.NCHW){
ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
} else {
ret = epsilons.reshape('c', epsilons.size(0), inputHeight, inputWidth, numChannels);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace
}

View File

@ -52,7 +52,8 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
if(rnnDataFormat != null)
this.rnnDataFormat = rnnDataFormat;
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {

View File

@ -57,7 +57,8 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
private RNNFormat rnnDataFormat = RNNFormat.NCW;
public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
this.rnnDataFormat = rnnDataFormat;
if(rnnDataFormat != null)
this.rnnDataFormat = rnnDataFormat;
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
@ -116,7 +117,7 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
}
InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
return InputType.feedForward(rnn.getSize());
return InputType.feedForward(rnn.getSize(), rnn.getFormat());
}
@Override

View File

@ -0,0 +1,52 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf.serde.format;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import java.io.IOException;
/**
* Simple JSON deserializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat}
*
* @author Alex Black
*/
public class DataFormatDeserializer extends JsonDeserializer<DataFormat> {
@Override
public DataFormat deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
JsonNode node = jp.getCodec().readTree(jp);
String text = node.textValue();
switch (text){
case "NCHW":
return CNN2DFormat.NCHW;
case "NHWC":
return CNN2DFormat.NHWC;
case "NCW":
return RNNFormat.NCW;
case "NWC":
return RNNFormat.NWC;
default:
return null;
}
}
}

View File

@ -0,0 +1,37 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.conf.serde.format;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;
import java.io.IOException;
/**
* Simple JSON deserializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat}
*
* @author Alex Black
*/
public class DataFormatSerializer extends JsonSerializer<DataFormat> {
@Override
public void serialize(DataFormat dataFormat, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
jsonGenerator.writeString(dataFormat.toString());
}
}

View File

@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.deeplearning4j.nn.workspace.ArrayType;
@ -48,14 +49,16 @@ public class MergeVertex extends BaseGraphVertex {
private long[][] forwardPassShapes;
private int fwdPassRank;
private int mergeAxis;
public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType) {
this(graph, name, vertexIndex, null, null, dataType);
public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType, int mergeAxis) {
this(graph, name, vertexIndex, null, null, dataType, mergeAxis);
}
public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
VertexIndices[] outputVertices, DataType dataType) {
VertexIndices[] outputVertices, DataType dataType, int mergeAxis) {
super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
this.mergeAxis = mergeAxis;
}
@Override
@ -92,7 +95,6 @@ public class MergeVertex extends BaseGraphVertex {
forwardPassShapes = new long[in.length][0];
val nExamples = in[0].size(0);
int nOut = 0;
fwdPassRank = in[0].rank();
for (int i = 0; i < in.length; i++) {
val currShape = in[i].shape();
@ -109,12 +111,11 @@ public class MergeVertex extends BaseGraphVertex {
+ Arrays.toString(in[0].shape()) + ", activations[" + i
+ "] shape: " + Arrays.toString(in[i].shape()));
}
nOut += currShape[1]; //Same dimension for all of CNNs, FF, RNNs
}
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
return Nd4j.concat(1, in);
INDArray out = Nd4j.concat(mergeAxis, in);
return out;
}
}
@ -145,20 +146,16 @@ public class MergeVertex extends BaseGraphVertex {
break;
case 3:
for (int i = 0; i < forwardPassShapes.length; i++) {
out[i].assign(epsilon.get(NDArrayIndex.all(), //All rows
NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //subset of columns
NDArrayIndex.all())); //All time steps
out[i].assign(epsilon.get(indices(3, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //All time steps
cumulative += forwardPassShapes[i][1];
cumulative += forwardPassShapes[i][mergeAxis];
}
break;
case 4:
for (int i = 0; i < forwardPassShapes.length; i++) {
out[i].assign(epsilon.get(NDArrayIndex.all(),
NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //Subset of depth
NDArrayIndex.all(), //Width
NDArrayIndex.all())); //height
cumulative += forwardPassShapes[i][1];
out[i].assign(epsilon.get(indices(4, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //height
cumulative += forwardPassShapes[i][mergeAxis];
}
break;
default:
@ -168,6 +165,19 @@ public class MergeVertex extends BaseGraphVertex {
return new Pair<>(null, out);
}
private INDArrayIndex[] indices(int num, int axis, long from, long to){
INDArrayIndex[] out = new INDArrayIndex[num];
for( int i=0; i<num; i++ ){
if(i == axis){
out[i] = NDArrayIndex.interval(from, to);
} else {
out[i] = NDArrayIndex.all();
}
}
return out;
}
@Override
public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
if (backpropGradientsViewArray != null)

View File

@ -79,7 +79,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c
ILossFunction lossFunction = layerConf().getLossFn();
double score = lossFunction.computeScore(getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut,
INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
double score = lossFunction.computeScore(labels2d, preOut,
layerConf().getActivationFn(), maskArray,false);
if(conf().isMiniBatch())

View File

@ -21,6 +21,7 @@ import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D;
@ -71,7 +72,11 @@ public class RepeatVector extends AbstractLayer<org.deeplearning4j.nn.conf.layer
INDArray outEpsilon;
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
outEpsilon = epsilon.sum(2);
if (layerConf().getDataFormat() == RNNFormat.NCW) {
outEpsilon = epsilon.sum(2);
}else{
outEpsilon = epsilon.sum(1);
}
}
Gradient gradient = new DefaultGradient();
@ -99,10 +104,22 @@ public class RepeatVector extends AbstractLayer<org.deeplearning4j.nn.conf.layer
long miniBatch = input.size(0);
long size = input.size(1);
INDArray output = input.reshape(miniBatch, size, 1).castTo(dataType);
try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
return output.repeat(2, (long) getN());
if (getDataFormat() == RNNFormat.NCW) {
INDArray output = input.reshape(miniBatch, size, 1).castTo(dataType);
try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
return output.repeat(2, (long) getN());
}
}
else{
INDArray output = input.reshape(miniBatch, 1, size).castTo(dataType);
try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
return output.repeat(1, (long) getN());
}
}
}
public RNNFormat getDataFormat(){
return layerConf().getDataFormat();
}
@Override

View File

@ -20,6 +20,8 @@ import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.Convolution1D;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
@ -74,6 +76,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
+ Arrays.toString(epsilon.shape())
+ ". Expected rank 3 array with shape [minibatchSize, features, length]. " + layerId());
if (getRnnDataFormat() == RNNFormat.NWC){
epsilon = epsilon.permute(0, 2, 1);
this.input = input.permute(0, 2, 1);
}
if(maskArray != null){
INDArray maskOut = feedForwardMaskArray(maskArray, MaskState.Active, (int)epsilon.size(0)).getFirst();
Preconditions.checkState(epsilon.size(0) == maskOut.size(0) && epsilon.size(2) == maskOut.size(1),
@ -125,6 +131,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY));
}
retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c');
if (getRnnDataFormat() == RNNFormat.NWC){
epsOut = epsOut.permute(0, 2, 1);
this.input = input.permute(0, 2, 1);
}
return new Pair<>(retGradient, epsOut);
}
@ -140,7 +150,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
// remove singleton fourth dimension from input and current epsilon
epsNext = epsNext.reshape(epsNext.size(0), epsNext.size(1), epsNext.size(2));
input = origInput;
if (getRnnDataFormat() == RNNFormat.NWC){
epsNext = epsNext.permute(0, 2, 1);
this.input = input.permute(0, 2, 1);
}
return new Pair<>(gradientEpsNext.getFirst(), epsNext);
}
@ -185,7 +198,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
.s(c.getStride()[0])
.d(c.getDilation()[0])
.p(c.getPadding()[0])
.dataFormat(Conv1DConfig.NCW)
.dataFormat((((org.deeplearning4j.nn.conf.layers.Convolution1DLayer)
layerConf()).getRnnDataFormat()== RNNFormat.NCW)?Conv1DConfig.NCW: Conv1DConfig.NCW)
.paddingMode(PaddingMode.CAUSAL)
.build();
INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
@ -209,6 +223,9 @@ public class Convolution1DLayer extends ConvolutionLayer {
@Override
public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
if (getRnnDataFormat() == RNNFormat.NWC){
this.input = input.permute(0, 2, 1);
}
INDArray act4d = super.activate(training, workspaceMgr);
INDArray act3d = act4d.reshape(act4d.size(0), act4d.size(1), act4d.size(2));
@ -219,6 +236,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
act3d.shape(), maskOut.shape());
Broadcast.mul(act3d, maskOut, act3d, 0, 2);
}
if (getRnnDataFormat() == RNNFormat.NWC){
this.input = input.permute(0, 2, 1);
act3d = act3d.permute(0, 2, 1);
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, act3d); //Should be zero copy most of the time
}
@ -231,4 +252,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
layerConf().getConvolutionMode());
return new Pair<>(reduced, currentMaskState);
}
private RNNFormat getRnnDataFormat(){
return ((org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf()).getRnnDataFormat();
}
}

View File

@ -160,7 +160,8 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr);
INDArray z = p.getFirst();
if(layerConf().getCnn2dDataFormat() != CNN2DFormat.NCHW){
CNN2DFormat f = layerConf().getCnn2dDataFormat();
if(f != CNN2DFormat.NCHW){
z = z.permute(0,3,1,2); //NHWC to NCHW
}
delta = afn.backprop(z, epsilon).getFirst(); //TODO handle activation function params

View File

@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
@ -64,8 +65,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
INDArray z = preOutput(true, workspaceMgr);
INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); //Shape: [mb, vector, seqLength]
boolean ncw = layerConf().getOutputFormat() == RNNFormat.NCW;
if (maskArray != null) {
delta = Broadcast.mul(delta, maskArray, delta, 0, 2);
if(ncw){
delta = Broadcast.mul(delta, maskArray, delta, 0, 2);
} else {
delta = Broadcast.mul(delta, maskArray, delta, 0, 1);
}
}
int inputLength = layerConf().getInputLength();
@ -76,7 +83,10 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
delta = delta.dup('c');
}
delta = delta.permute(0, 2, 1); //From [minibatch, nOut, length] to [minibatch, length, nOut]
if(ncw){
delta = delta.permute(0, 2, 1); //From [minibatch, nOut, length] to [minibatch, length, nOut]
}
delta = delta.reshape('c',inputLength * numSamples, nOut);
INDArray weightGradients = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY);
@ -159,7 +169,10 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
}
val shape = new long[]{minibatch, inputLength, nOut};
INDArray ret = rows.reshape('c', shape).permute(0, 2, 1);
INDArray ret = rows.reshape('c', shape);
if(layerConf().getOutputFormat() == RNNFormat.NCW){
ret = ret.permute(0, 2, 1); //[minibatch, seqLen, nOut] -> [minibatch, nOut, seqLen] i.e., NWC -> NCW
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
}
@ -177,8 +190,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
" 2 (when input is rank 3, shape [mb,1,tsLength]). Input shape: " + Arrays.toString(input.shape()) +
", mask shape: " + Arrays.toString(maskArray.shape()));
}
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
boolean ncw = layerConf().getOutputFormat() == RNNFormat.NCW;
if(ncw){
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
} else {
//Returned array: rank 3, shape [mb, seqLength, vector]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 1);
}
}
return ret;
}

View File

@ -20,11 +20,13 @@ import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.util.TimeSeriesUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -152,6 +154,14 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
assertInputSet(false);
Preconditions.checkState(input.rank() == 3,
"3D input expected to RNN layer expected, got " + input.rank());
boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(layerConf()) == RNNFormat.NWC;
INDArray origInput = input;
if(nwc){
input = permuteIfNWC(input);
}
applyDropOutIfNecessary(training, workspaceMgr);
//TODO LSTM cache mode is disabled for now - not passing all tests
@ -166,7 +176,6 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
final INDArray recurrentWeights = getParamWithNoise(LSTMParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
final INDArray inputWeights = getParamWithNoise(LSTMParamInitializer.INPUT_WEIGHT_KEY, training, workspaceMgr); //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
final INDArray biases = getParamWithNoise(LSTMParamInitializer.BIAS_KEY, training, workspaceMgr); //by row: IFOG //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T
INDArray input = permuteIfNWC(this.input);
FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
prevMemCellState, (training && cacheMode != CacheMode.NONE) || forBackprop, true,
@ -178,6 +187,11 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
if (training && cacheMode != CacheMode.NONE) {
cachedFwdPass = fwd;
}
if(nwc){
input = origInput;
}
return fwd;
}

View File

@ -61,11 +61,8 @@ public class LastTimeStepLayer extends BaseWrapperLayer {
@Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
long[] newEpsShape = origOutputShape;
boolean nwc = (underlying instanceof BaseRecurrentLayer &&
((BaseRecurrentLayer) underlying).getDataFormat() == RNNFormat.NWC)||
(underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof
BaseRecurrentLayer && ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat()
== RNNFormat.NWC);
boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC;
INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f');
if(lastTimeStepIdxs == null){
//no mask case

View File

@ -58,7 +58,8 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
"Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
" Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
}
int td = (layerConf().getRnnDataFormat()==RNNFormat.NCW)? 2: 1;
RNNFormat format = layerConf().getRnnDataFormat();
int td = (format == RNNFormat.NCW) ? 2 : 1;
Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
"Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);

View File

@ -118,6 +118,7 @@ public class ConvolutionParamInitializer implements ParamInitializer {
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
conf.addVariable(WEIGHT_KEY);
conf.addVariable(BIAS_KEY);
conf.addVariable(BIAS_KEY);
} else {
INDArray weightView = paramsView;
params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));

View File

@ -34,6 +34,7 @@ import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -53,7 +54,7 @@ import static org.junit.Assert.*;
/**
* @author Tamas Fenyvesi
*/
@Slf4j
@Slf4j @Ignore //https://github.com/eclipse/deeplearning4j/issues/8891
public class TestVertxUIMultiSession extends BaseDL4JTest {
@Before