tf.keras import test and fixes (#347)
* merge conf
* merge conf
* tfkeras tests
* parameterized tests
* rename
* cuda versions
* jccp versions
* updates
* updates
* rnn+mlp passing
* repeat
* updates
* tests
* Update pom.xml
* Update pom.xml
* rem print
* cnn1d model conversion fixed
* cnn1d activate fixed
* cnn1d output shape fix
* cnn1d bprop fix
* cnn1d stack fix
* KerasModelEndToEndTest - Remove permutes for NWC and NHWC format tests
* Fixes and update test - input shapes (NCHW -> NHWC input)
* Ignore for known bad tests
* Multiple fixes - MergeVertex, CNN1D layers, etc.
* Fix issue with RNN/FF preprocessors, time distributed etc. with NWC format
* LSTM NWC dropout fix
* Add sequence embedding layer NWC support (configurable output format)
* Fix expected shape in a couple of tests - NWC expected
* Fix EmbeddingSequenceLayer backprop for NWC output case + add gradient checks
* CnnToFeedForwardPreprocessor: align with Keras/TF; fix Keras reshape/flatten
* Update ConvDataFormatTests to match new reshape behaviour
* Switch hard-coded path to ResourceUtils.listClassPathfiles for TestTFKerasModelImport
* TestUtils fix
* Fixes
* Fix JSON serde issue with data formats
* Fix for input dtype inference; fix 2 tests
* Test fixes
* #8891 Ignore for TestVertxUIMultiSession until fixed
* Restore but deprecate TensorFlowCnnToFeedForwardPreProcessor for older zoo models
* Ignore for deprecated preprocessor in DTypeTests
* Remove debug printlns

Signed-off-by: Alex Black <blacka101@gmail.com>
Co-authored-by: Alex Black <blacka101@gmail.com>
parent b9d5f1645b
commit 4cb87a94e8
@@ -20,6 +20,7 @@ package org.datavec.python;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.io.IOUtils;
+import org.bytedeco.cpython.global.python;
 import org.bytedeco.numpy.global.numpy;
 import org.nd4j.linalg.api.concurrency.AffinityManager;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -343,6 +344,19 @@ public class PythonExecutioner {
         if (path == null) {
             log.info("Setting python default path");
             File[] packages = numpy.cachePackages();
+
+            //// TODO: fix in javacpp
+            File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
+            File[] packages2 = new File[packages.length + 1];
+            for (int i = 0; i < packages.length; i++){
+                //System.out.println(packages[i].getAbsolutePath());
+                packages2[i] = packages[i];
+            }
+            packages2[packages.length] = sitePackagesWindows;
+            //System.out.println(sitePackagesWindows.getAbsolutePath());
+            packages = packages2;
+            //////////
+
             Py_SetPath(packages);
         } else {
             log.info("Setting python path " + path);
@@ -0,0 +1,132 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.datavec.python;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

@Slf4j
public class PythonProcess {
    private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);

    public static String runAndReturn(String... arguments) throws IOException, InterruptedException {
        String[] allArgs = new String[arguments.length + 1];
        for (int i = 0; i < arguments.length; i++){
            allArgs[i + 1] = arguments[i];
        }
        allArgs[0] = pythonExecutable;
        log.info("Executing command: " + Arrays.toString(allArgs));
        ProcessBuilder pb = new ProcessBuilder(allArgs);
        Process process = pb.start();
        String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
        process.waitFor();
        return out;
    }

    public static void run(String... arguments) throws IOException, InterruptedException {
        String[] allArgs = new String[arguments.length + 1];
        for (int i = 0; i < arguments.length; i++){
            allArgs[i + 1] = arguments[i];
        }
        allArgs[0] = pythonExecutable;
        log.info("Executing command: " + Arrays.toString(allArgs));
        ProcessBuilder pb = new ProcessBuilder(allArgs);
        pb.inheritIO().start().waitFor();
    }

    public static void pipInstall(String packageName) throws PythonException {
        try{
            run("-m", "pip", "install", packageName);
        }catch(Exception e){
            throw new PythonException("Error installing package " + packageName, e);
        }
    }

    public static void pipInstall(String packageName, String version) throws PythonException {
        pipInstall(packageName + "==" + version);
    }

    public static void pipUninstall(String packageName) throws PythonException {
        try{
            run("-m", "pip", "uninstall", packageName);
        }catch(Exception e){
            throw new PythonException("Error uninstalling package " + packageName, e);
        }
    }

    public static void pipInstallFromGit(String gitRepoUrl) throws PythonException {
        if (!gitRepoUrl.contains("://")){
            gitRepoUrl = "git://" + gitRepoUrl;
        }
        try{
run("-m", "pip", "install", "git+", gitRepoUrl);
|
||||
        }catch(Exception e){
            throw new PythonException("Error installing package from " + gitRepoUrl, e);
        }
    }

    public static String getPackageVersion(String packageName) throws PythonException {
        String out;
        try{
            out = runAndReturn("-m", "pip", "show", packageName);
        } catch (Exception e){
            throw new PythonException("Error finding version for package " + packageName, e);
        }

        if (!out.contains("Version: ")){
            throw new PythonException("Can't find package " + packageName);
        }
        String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
        return pkgVersion;
    }

    public static boolean isPackageInstalled(String packageName) throws PythonException {
        try{
            String out = runAndReturn("-m", "pip", "show", packageName);
            return !out.isEmpty();
        }catch (Exception e){
            throw new PythonException("Error checking if package is installed: " + packageName, e);
        }
    }

    public static void pipInstallFromRequirementsTxt(String path) throws PythonException {
        try{
            run("-m", "pip", "install", "-r", path);
        }catch (Exception e){
            throw new PythonException("Error installing packages from " + path, e);
        }
    }

    public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException {
        try{
            run(path, inplace ? "develop" : "install");
        }catch (Exception e){
            throw new PythonException("Error installing package from " + path, e);
        }
    }
}
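A minimal usage sketch of the class above (the package name is illustrative; it assumes the javacpp-presets CPython artifact is on the classpath so Loader.load(...) can resolve the embedded interpreter):

import org.datavec.python.PythonProcess;

public class PythonProcessExample {
    public static void main(String[] args) throws Exception {
        // Install a pinned version if the package is missing, then report it.
        if (!PythonProcess.isPackageInstalled("numpy")) {
            PythonProcess.pipInstall("numpy", "1.18.1"); // version chosen for illustration
        }
        System.out.println("numpy " + PythonProcess.getPackageVersion("numpy"));
    }
}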
@@ -0,0 +1,144 @@
package org.datavec.python.keras;

import org.datavec.python.Python;
import org.datavec.python.PythonException;
import org.datavec.python.PythonObject;
import org.datavec.python.PythonProcess;
import org.nd4j.linalg.api.ndarray.INDArray;

public class Model {

    private PythonObject pyModel;

    private static PythonObject installAndImportTF() throws PythonException {
        if (!PythonProcess.isPackageInstalled("tensorflow")){
            PythonProcess.pipInstall("tensorflow");
        }
        return Python.importModule("tensorflow");
    }

    private static PythonObject getKerasModule() throws PythonException {
        PythonObject tf = installAndImportTF();
        PythonObject keras = tf.attr("keras");
        tf.del();
        return keras;
    }

    private static PythonObject loadModel(String s) throws PythonException {
        PythonObject models = getKerasModule().attr("models");
        PythonObject loadModelF = models.attr("load_model");
        PythonObject model = loadModelF.call(s);
        models.del();
        loadModelF.del();
        return model;
    }

    public Model(String path) throws PythonException {
        pyModel = loadModel(path);
    }

    public INDArray[] predict(INDArray... inputs) throws PythonException {
        PythonObject predictF = pyModel.attr("predict");
        PythonObject inputList = new PythonObject(inputs);
        PythonObject pyOut = predictF.call(inputList);
        INDArray[] out;
        if (Python.isinstance(pyOut, Python.listType())){
            out = new INDArray[Python.len(pyOut).toInt()];
            for(int i = 0; i < out.length; i++){
                out[i] = pyOut.get(i).toNumpy().getNd4jArray();
            }
        } else {
            out = new INDArray[]{pyOut.toNumpy().getNd4jArray()};
        }

        predictF.del();
        inputList.del();
        pyOut.del();
        return out;
    }

    public int numInputs(){
        PythonObject inputs = pyModel.attr("inputs");
        PythonObject pyNumInputs = Python.len(inputs);
        int ret = pyNumInputs.toInt();
        inputs.del();
        pyNumInputs.del();
        return ret;
    }

    public int numOutputs(){
        PythonObject outputs = pyModel.attr("outputs");
        PythonObject pyNumOutputs = Python.len(outputs);
        int ret = pyNumOutputs.toInt();
        outputs.del();
        pyNumOutputs.del();
        return ret;
    }

    public long[][] inputShapes(){
        long[][] ret = new long[numInputs()][];
        for (int i = 0; i < ret.length; i++){
            ret[i] = inputShapeAt(i);
        }
        return ret;
    }

    public long[][] outputShapes(){
        long[][] ret = new long[numOutputs()][];
        for (int i = 0; i < ret.length; i++){
            ret[i] = outputShapeAt(i);
        }
        return ret;
    }

    public long[] inputShapeAt(int input){
        PythonObject inputs = pyModel.attr("inputs");
        PythonObject tensor = inputs.get(input);
        PythonObject tensorShape = tensor.attr("shape");
        PythonObject shapeList = Python.list(tensorShape);
        PythonObject pyNdim = Python.len(shapeList);
        int ndim = pyNdim.toInt();
        long[] shape = new long[ndim];
        for(int i = 0; i < shape.length; i++){
            PythonObject pyDim = shapeList.get(i);
            if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
                shape[i] = -1;
            } else {
                shape[i] = pyDim.toLong();
            }
        }
        pyNdim.del();
        shapeList.del();
        tensorShape.del();
        tensor.del();
        inputs.del();
        return shape;
    }

    public long[] outputShapeAt(int output){
        PythonObject outputs = pyModel.attr("outputs");
        PythonObject tensor = outputs.get(output);
        PythonObject tensorShape = tensor.attr("shape");
        PythonObject shapeList = Python.list(tensorShape);
        PythonObject pyNdim = Python.len(shapeList);
        int ndim = pyNdim.toInt();
        long[] shape = new long[ndim];
        for(int i = 0; i < shape.length; i++){
            PythonObject pyDim = shapeList.get(i);
            if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){
                shape[i] = -1;
            } else {
                shape[i] = pyDim.toLong();
            }
        }
        pyNdim.del();
        shapeList.del();
        tensorShape.del();
        tensor.del();
        outputs.del();
        return shape;
    }
}
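As a sketch of intended usage (model path and input shape are illustrative and must match the loaded network; dimensions reported as -1 by inputShapeAt are dynamic):

import org.datavec.python.keras.Model;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

Model model = new Model("/path/to/model.h5");    // any tf.keras-saved model
System.out.println(model.numInputs() + " input(s), " + model.numOutputs() + " output(s)");
INDArray in = Nd4j.rand(DataType.FLOAT, 1, 784); // assumed input shape for illustration
INDArray[] out = model.predict(in);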
@@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.layers.BaseLayer;
 import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
 import org.deeplearning4j.nn.graph.ComputationGraph;
@@ -153,11 +154,22 @@ public class TestUtils {
         return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
     }

-    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
-        INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
+    public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) {
+        return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng);
+    }
+
+    public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){
+        boolean ncw = format == RNNFormat.NCW;
+        long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize};
+        char order = ncw ? 'f' : 'c';
+        INDArray out = Nd4j.create(DataType.FLOAT, shape, order);
         for( int i=0; i<minibatch; i++ ){
             for( int j=0; j<tsLength; j++ ){
-                out.putScalar(i, rng.nextInt(outSize), j, 1.0);
+                if(ncw){
+                    out.putScalar(i, rng.nextInt(outSize), j, 1.0);
+                } else {
+                    out.putScalar(i, j, rng.nextInt(outSize), 1.0);
+                }
             }
         }
         return out;
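For reference, the two layouts produced by the new overload (arguments taken from the signature above; seed arbitrary):

INDArray ncw = TestUtils.randomOneHotTimeSeries(RNNFormat.NCW, 2, 3, 4, new Random(12345)); // shape [2, 3, 4]
INDArray nwc = TestUtils.randomOneHotTimeSeries(RNNFormat.NWC, 2, 3, 4, new Random(12345)); // shape [2, 4, 3]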
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
 import org.deeplearning4j.nn.conf.layers.*;
@@ -560,75 +561,81 @@ public class GradientCheckTests extends BaseDL4JTest {
     public void testEmbeddingSequenceLayer(){
         Nd4j.getRandom().setSeed(12345);

-        for(boolean maskArray : new boolean[]{false, true}){
-            for(int inputRank : new int[]{2,3}) {
-
-                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
-                        .dataType(DataType.DOUBLE)
-                        .seed(12345)
-                        .updater(new NoOp())
-                        .weightInit(new NormalDistribution(0, 1))
-                        .list()
-                        .layer(new EmbeddingSequenceLayer.Builder()
-                                .nIn(8)
-                                .nOut(4)
-                                .build())
-                        .layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
-                                .lossFunction(LossFunction.MSE).build())
-                        .build();
-
-                MultiLayerNetwork net = new MultiLayerNetwork(conf);
-                net.init();
-
-                INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
-                INDArray label = Nd4j.rand(new int[]{3, 3, 6});
-
-                if(inputRank == 3){
-                    //Reshape from [3,6] to [3,1,6]
-                    in = in.reshape('c', 3, 1, 6);
-                }
-
-                INDArray fMask = null;
-                if (maskArray) {
-                    fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
-                            {1, 1, 0, 0, 0, 0},
-                            {1, 0, 0, 0, 0, 0}});
-
-                }
-
-                String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
-                boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
-                        .labels(label).inputMask(fMask));
-                assertTrue(msg, gradOK);
-                TestUtils.testModelSerialization(net);
-
-
-                //Also: if mask is present, double check that the masked steps don't impact score
-                if (maskArray) {
-                    DataSet ds = new DataSet(in, label, fMask, null);
-                    double score = net.score(ds);
-                    if(inputRank == 2){
-                        in.putScalar(1, 2, 0);
-                        in.putScalar(2, 1, 0);
-                        in.putScalar(2, 2, 0);
-                    } else {
-                        in.putScalar(1, 0, 2, 0);
-                        in.putScalar(2, 0, 1, 0);
-                        in.putScalar(2, 0, 2, 0);
-                    }
-                    double score2 = net.score(ds);
-                    assertEquals(score, score2, 1e-6);
-                    if(inputRank == 2){
-                        in.putScalar(1, 2, 1);
-                        in.putScalar(2, 1, 1);
-                        in.putScalar(2, 2, 1);
-                    } else {
-                        in.putScalar(1, 0, 2, 1);
-                        in.putScalar(2, 0, 1, 1);
-                        in.putScalar(2, 0, 2, 1);
-                    }
-                    double score3 = net.score(ds);
-                    assertEquals(score, score3, 1e-6);
-                }
+        for(RNNFormat seqOutputFormat : RNNFormat.values()) {
+            for (boolean maskArray : new boolean[]{false, true}) {
+                for (int inputRank : new int[]{2, 3}) {
+
+                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                            .dataType(DataType.DOUBLE)
+                            .seed(12345)
+                            .updater(new NoOp())
+                            .weightInit(new NormalDistribution(0, 1))
+                            .list()
+                            .layer(new EmbeddingSequenceLayer.Builder()
+                                    .nIn(8)
+                                    .nOut(4)
+                                    .outputDataFormat(seqOutputFormat)
+                                    .build())
+                            .layer(new RnnOutputLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH)
+                                    .dataFormat(seqOutputFormat)
+                                    .lossFunction(LossFunction.MSE).build())
+                            .build();
+
+                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                    net.init();
+
+                    boolean ncw = seqOutputFormat == RNNFormat.NCW;
+
+                    INDArray in = Transforms.floor(Nd4j.rand(3, 6).muli(8)); //Integers 0 to 7 inclusive
+                    INDArray label = Nd4j.rand(DataType.FLOAT, ncw ? new int[]{3, 3, 6} : new int[]{3,6,3});
+
+                    if (inputRank == 3) {
+                        //Reshape from [3,6] to [3,1,6]
+                        in = in.reshape('c', 3, 1, 6);
+                    }
+
+                    INDArray fMask = null;
+                    if (maskArray) {
+                        fMask = Nd4j.create(new double[][]{{1, 1, 1, 1, 1, 1},
+                                {1, 1, 0, 0, 0, 0},
+                                {1, 0, 0, 0, 0, 0}});
+
+                    }
+
+                    String msg = "mask=" + maskArray + ", inputRank=" + inputRank;
+                    boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(in)
+                            .labels(label).inputMask(fMask));
+                    assertTrue(msg, gradOK);
+                    TestUtils.testModelSerialization(net);
+
+
+                    //Also: if mask is present, double check that the masked steps don't impact score
+                    if (maskArray) {
+                        DataSet ds = new DataSet(in, label, fMask, null);
+                        double score = net.score(ds);
+                        if (inputRank == 2) {
+                            in.putScalar(1, 2, 0);
+                            in.putScalar(2, 1, 0);
+                            in.putScalar(2, 2, 0);
+                        } else {
+                            in.putScalar(1, 0, 2, 0);
+                            in.putScalar(2, 0, 1, 0);
+                            in.putScalar(2, 0, 2, 0);
+                        }
+                        double score2 = net.score(ds);
+                        assertEquals(score, score2, 1e-6);
+                        if (inputRank == 2) {
+                            in.putScalar(1, 2, 1);
+                            in.putScalar(2, 1, 1);
+                            in.putScalar(2, 2, 1);
+                        } else {
+                            in.putScalar(1, 0, 2, 1);
+                            in.putScalar(2, 0, 1, 1);
+                            in.putScalar(2, 0, 2, 1);
+                        }
+                        double score3 = net.score(ds);
+                        assertEquals(score, score3, 1e-6);
+                    }
+                }
             }
         }
@@ -21,9 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
 import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
-import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.*;
 import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
 import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
@@ -341,104 +339,112 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
     @Test
     public void testCnnDepthMerge() {

-        Nd4j.getRandom().setSeed(12345);
-        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
-                .dataType(DataType.DOUBLE)
-                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
-                .dist(new NormalDistribution(0, 0.1))
-                .updater(new NoOp()).graphBuilder().addInputs("input")
-                .addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
-                        .nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
-                .addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
-                        .nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
-                .addVertex("merge", new MergeVertex(), "l1", "l2")
-                .addLayer("outputLayer",
-                        new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
-                                .activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
-                                .build(),
-                        "merge")
-                .setOutputs("outputLayer")
-                .inputPreProcessor("outputLayer", new CnnToFeedForwardPreProcessor(5, 5, 4))
-                .build();
-
-        ComputationGraph graph = new ComputationGraph(conf);
-        graph.init();
-
-        Random r = new Random(12345);
-        INDArray input = Nd4j.rand(new int[] {5, 2, 6, 6}); //Order: examples, channels, height, width
-        INDArray labels = Nd4j.zeros(5, 3);
-        for (int i = 0; i < 5; i++)
-            labels.putScalar(new int[] {i, r.nextInt(3)}, 1.0);
-
-        if (PRINT_RESULTS) {
-            System.out.println("testCnnDepthMerge()");
-        }
-
-        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
-                .labels(new INDArray[]{labels}));
-
-        String msg = "testCnnDepthMerge()";
-        assertTrue(msg, gradOK);
-        TestUtils.testModelSerialization(graph);
+        for(CNN2DFormat format : CNN2DFormat.values()) {
+
+            String msg = "testCnnDepthMerge - " + format;
+
+            Nd4j.getRandom().setSeed(12345);
+            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
+                    .dataType(DataType.DOUBLE)
+                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+                    .dist(new NormalDistribution(0, 0.1))
+                    .updater(new NoOp()).graphBuilder().addInputs("input")
+                    .addLayer("l1", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
+                            .nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
+                    .addLayer("l2", new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
+                            .nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
+                    .addVertex("merge", new MergeVertex(), "l1", "l2")
+                    .addLayer("outputLayer",
+                            new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
+                                    .activation(Activation.SOFTMAX).nIn(5 * 5 * (2 + 2)).nOut(3)
+                                    .build(),
+                            "merge")
+                    .setOutputs("outputLayer")
+                    .setInputTypes(InputType.convolutional(6, 6, 2, format))
+                    .build();
+
+            ComputationGraph graph = new ComputationGraph(conf);
+            graph.init();
+
+            Random r = new Random(12345);
+            INDArray input = Nd4j.rand(DataType.DOUBLE, format == CNN2DFormat.NCHW ? new long[]{5,2,6,6} : new long[]{5,6,6,2});
+            INDArray labels = Nd4j.zeros(5, 3);
+            for (int i = 0; i < 5; i++)
+                labels.putScalar(new int[]{i, r.nextInt(3)}, 1.0);
+
+            if (PRINT_RESULTS) {
+                System.out.println(msg);
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+            }
+
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                    .labels(new INDArray[]{labels}));
+
+            assertTrue(msg, gradOK);
+            TestUtils.testModelSerialization(graph);
+        }
     }

     @Test
     public void testRNNWithMerging() {

-        Nd4j.getRandom().setSeed(12345);
-        ComputationGraphConfiguration conf =
-                new NeuralNetConfiguration.Builder().seed(12345)
-                        .dataType(DataType.DOUBLE)
-                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
-                        .dist(new UniformDistribution(0.2, 0.6))
-                        .updater(new NoOp()).graphBuilder().addInputs("input")
-                        .setOutputs("out")
-                        .addLayer("lstm1",
-                                new SimpleRnn.Builder().nIn(3).nOut(3)
-                                        .activation(Activation.TANH).build(),
-                                "input")
-                        .addLayer("lstm2",
-                                new SimpleRnn.Builder().nIn(3).nOut(3)
-                                        .activation(Activation.TANH).build(),
-                                "lstm1")
-                        .addLayer("dense1",
-                                new DenseLayer.Builder().nIn(3).nOut(3)
-                                        .activation(Activation.SIGMOID).build(),
-                                "lstm1")
-                        .addLayer("lstm3",
-                                new SimpleRnn.Builder().nIn(3).nOut(3)
-                                        .activation(Activation.TANH).build(),
-                                "dense1")
-                        .addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
-                        .addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
-                                .activation(Activation.SOFTMAX)
-                                .lossFunction(LossFunctions.LossFunction.MCXENT).build(),
-                                "merge")
-                        .inputPreProcessor("dense1", new RnnToFeedForwardPreProcessor())
-                        .inputPreProcessor("lstm3", new FeedForwardToRnnPreProcessor())
-                        .build();
-
-        ComputationGraph graph = new ComputationGraph(conf);
-        graph.init();
-
-        Random r = new Random(12345);
-        INDArray input = Nd4j.rand(new int[] {2, 3, 4});
-        INDArray labels = TestUtils.randomOneHotTimeSeries(2, 3, 4);
-
-        if (PRINT_RESULTS) {
-            System.out.println("testLSTMWithMerging()");
-        }
-
-        boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
-                .labels(new INDArray[]{labels}));
-
-        String msg = "testLSTMWithMerging()";
-        assertTrue(msg, gradOK);
-        TestUtils.testModelSerialization(graph);
+        for(RNNFormat format : RNNFormat.values()) {
+
+            String msg = "testLSTMWithMerging - " + format;
+
+            Nd4j.getRandom().setSeed(12345);
+            ComputationGraphConfiguration conf =
+                    new NeuralNetConfiguration.Builder().seed(12345)
+                            .dataType(DataType.DOUBLE)
+                            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+                            .dist(new UniformDistribution(0.2, 0.6))
+                            .updater(new NoOp()).graphBuilder().addInputs("input")
+                            .setOutputs("out")
+                            .addLayer("lstm1",
+                                    new SimpleRnn.Builder().nIn(3).nOut(3)
+                                            .activation(Activation.TANH).build(),
+                                    "input")
+                            .addLayer("lstm2",
+                                    new SimpleRnn.Builder().nIn(3).nOut(3)
+                                            .activation(Activation.TANH).build(),
+                                    "lstm1")
+                            .addLayer("dense1",
+                                    new DenseLayer.Builder().nIn(3).nOut(3)
+                                            .activation(Activation.SIGMOID).build(),
+                                    "lstm1")
+                            .addLayer("lstm3",
+                                    new SimpleRnn.Builder().nIn(3).nOut(3)
+                                            .activation(Activation.TANH).build(),
+                                    "dense1")
+                            .addVertex("merge", new MergeVertex(), "lstm2", "lstm3")
+                            .addLayer("out", new RnnOutputLayer.Builder().nIn(6).nOut(3)
+                                    .activation(Activation.SOFTMAX)
+                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(),
+                                    "merge")
+                            .setInputTypes(InputType.recurrent(4, format))
+                            .build();
+
+            ComputationGraph graph = new ComputationGraph(conf);
+            graph.init();
+
+            Random r = new Random(12345);
+            INDArray input = Nd4j.rand(DataType.DOUBLE, format == RNNFormat.NCW ? new long[]{2, 3, 4} : new long[]{2,4,3});
+            INDArray labels = TestUtils.randomOneHotTimeSeries(format, 2, 3, 4, new Random(12345));
+
+            if (PRINT_RESULTS) {
+                System.out.println(msg);
+//                for (int j = 0; j < graph.getNumLayers(); j++)
+//                    System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams());
+            }
+
+            boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input})
+                    .labels(new INDArray[]{labels}));
+
+            assertTrue(msg, gradOK);
+            TestUtils.testModelSerialization(graph);
+        }
     }

     @Test
@@ -17,7 +17,9 @@
 package org.deeplearning4j.nn.dtypes;

 import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
+import org.deeplearning4j.nn.conf.preprocessor.*;
 import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer;
+import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
 import org.nd4j.shade.guava.collect.ImmutableSet;
 import org.nd4j.shade.guava.reflect.ClassPath;
 import lombok.extern.slf4j.Slf4j;
@@ -51,16 +53,11 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
 import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
 import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer;
 import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
-import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
-import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
-import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor;
-import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.util.IdentityLayer;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
 import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
-import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.weights.WeightInitDistribution;
@@ -97,7 +94,8 @@ public class DTypeTests extends BaseDL4JTest {
             Pooling2D.class, //Alias for SubsamplingLayer
             Convolution2D.class, //Alias for ConvolutionLayer
             Pooling1D.class, //Alias for Subsampling1D
-            Convolution1D.class //Alias for Convolution1DLayer
+            Convolution1D.class, //Alias for Convolution1DLayer
+            TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
     ));

     @Override
@@ -1078,7 +1076,7 @@ public class DTypeTests extends BaseDL4JTest {
                     .addLayer("l", new DenseLayer.Builder().nOut(16).build(), "in")
                     .addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l")
                     .addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc")
-                    .addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16})), "preproc2")
+                    .addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16}, false)), "preproc2")
                     .addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3")
                     .setInputTypes(InputType.feedForward(5))
                     .setOutputs("out");
@@ -1150,7 +1148,7 @@ public class DTypeTests extends BaseDL4JTest {
                 case 7:
                     b.addInputs("in")
                             .addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in")
-                            .addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5)), "1")
+                            .addVertex("2", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 5)), "1")
                             .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2")
                             .setOutputs("out")
                             .setInputTypes(InputType.convolutional(28, 28, 1));
@@ -60,7 +60,7 @@ public class TestGraphNodes extends BaseDL4JTest {
     @Test
     public void testMergeNode() {
         Nd4j.getRandom().setSeed(12345);
-        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
+        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);

         INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4);
         INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100);
@@ -82,7 +82,7 @@ public class TestGraphNodes extends BaseDL4JTest {
     public void testMergeNodeRNN() {

         Nd4j.getRandom().setSeed(12345);
-        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
+        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);

         INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5);
         INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100);
@@ -103,7 +103,7 @@ public class TestGraphNodes extends BaseDL4JTest {
     @Test
     public void testCnnDepthMerge() {
         Nd4j.getRandom().setSeed(12345);
-        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType());
+        GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1);

         INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2);
         INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10);
@@ -15,14 +15,13 @@
  ******************************************************************************/
 package org.deeplearning4j.nn.layers.convolution;

-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
+import lombok.*;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
+import org.deeplearning4j.nn.api.MaskState;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;

@@ -30,8 +29,12 @@ import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
 import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
 import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
 import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
+import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.ArrayType;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -516,6 +519,40 @@ public class ConvDataFormatTests extends BaseDL4JTest {
         }
     }

+
+    @Test
+    public void testGlobalPooling() {
+        try {
+            for (boolean helpers : new boolean[]{false, true}) {
+                for (PoolingType pt : PoolingType.values()) {
+                    Nd4j.getRandom().setSeed(12345);
+                    Nd4j.getEnvironment().allowHelpers(helpers);
+                    String msg = helpers ? "With helpers (" + pt + ")" : "No helpers (" + pt + ")";
+                    System.out.println(" --- " + msg + " ---");
+
+                    INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12);
+                    INDArray labels = TestUtils.randomOneHot(2, 10);
+
+                    TestCase tc = TestCase.builder()
+                            .msg(msg)
+                            .net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true))
+                            .net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false))
+                            .net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true))
+                            .net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false))
+                            .inNCHW(inNCHW)
+                            .labelsNCHW(labels)
+                            .labelsNHWC(labels)
+                            .testLayerIdx(1)
+                            .build();
+
+                    testHelper(tc);
+                }
+            }
+        } finally {
+            Nd4j.getEnvironment().allowHelpers(true);
+        }
+    }
+
     private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) {
         if (setOnLayerAlso) {
             return getNetWithLayer(new ConvolutionLayer.Builder()
@@ -735,11 +772,28 @@ public class ConvDataFormatTests extends BaseDL4JTest {
                 .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
                 .setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));

+        if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
+            //Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
+            //DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
+            builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
+        }
+
         MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
         net.init();
         return net;
     }

+    private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) {
+        if (setOnLayerAlso) {
+            return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
+                    .poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2})
+                    .build(), format, ConvolutionMode.Same, null);
+        } else {
+            return getNetWithLayer(new GlobalPoolingLayer.Builder(pt)
+                    .build(), format, ConvolutionMode.Same, null);
+        }
+    }
+
     private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){
         NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder()
                 .seed(12345)
@@ -799,8 +853,13 @@ public class ConvDataFormatTests extends BaseDL4JTest {
         INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1);

         assertEquals(tc.msg, l0_1, l0_2);
-        assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
-        assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
+        if(l0_1.rank() == 4) {
+            assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2));
+            assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2));
+        } else {
+            assertEquals(tc.msg, l0_1, l0_3);
+            assertEquals(tc.msg, l0_1, l0_4);
+        }


         INDArray out1 = tc.net1.output(inNCHW);
@@ -880,4 +939,36 @@ public class ConvDataFormatTests extends BaseDL4JTest {
         }
         return differs;
     }
+
+
+    //Converts NHWC to NCHW activations
+    @EqualsAndHashCode
+    private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
+
+        @Override
+        public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
+        }
+
+        @Override
+        public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
+        }
+
+        @Override
+        public InputPreProcessor clone() {
+            return this;
+        }
+
+        @Override
+        public InputType getOutputType(InputType inputType) {
+            InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
+            return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
+        }
+
+        @Override
+        public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
+            return null;
+        }
+    }
 }
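A quick shape check of the permutes used by this preprocessor (batch/spatial sizes chosen to match the 12x12x3 inputs used elsewhere in the test):

INDArray nhwc = Nd4j.rand(DataType.FLOAT, 2, 12, 12, 3);
INDArray nchw = nhwc.permute(0, 3, 1, 2);  // [2, 3, 12, 12] - forward direction
INDArray back = nchw.permute(0, 2, 3, 1);  // [2, 12, 12, 3] - backprop direction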
@@ -44,6 +44,7 @@ import org.nd4j.linalg.primitives.Pair;

 import java.util.Arrays;
 import java.util.List;
+import java.util.Random;

 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
@@ -217,7 +218,7 @@ public class TestRnnLayers extends BaseDL4JTest {
             NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder()

                     .list()
-                    .layer(new SimpleRnn.Builder().nIn(5).nOut(5).build());
+                    .layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build());

             switch (i){
                 case 0:
@@ -235,10 +236,7 @@ public class TestRnnLayers extends BaseDL4JTest {
             net.init();

             INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5);
-            INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10);
-            if (rnnDataFormat == RNNFormat.NWC){
-                l = l.permute(0, 2, 1);
-            }
+            INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new Random(12345));
             try{
                 net.fit(in,l);
             } catch (Throwable t){
@@ -61,15 +61,13 @@ public class TestSimpleRnn extends BaseDL4JTest {
         int tsLength = 7;
         INDArray in;
         if (rnnDataFormat == RNNFormat.NCW){
-            in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength});
+            in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength);
         }
         else{
-            in = Nd4j.rand(DataType.FLOAT, new int[]{m, tsLength, nIn});
+            in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn);
         }

-
-        // in.get(all(), all(), interval(1,tsLength)).assign(0);

         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                 .updater(new NoOp())
                 .weightInit(WeightInit.XAVIER)
@@ -7,10 +7,14 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.WorkspaceMode;
 import org.deeplearning4j.nn.conf.inputs.InputType;
-import org.deeplearning4j.nn.conf.layers.DenseLayer;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer;
 import org.deeplearning4j.nn.conf.layers.LSTM;
 import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
+import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
 import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
 import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed;
+import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -106,4 +110,73 @@ public class TestTimeDistributed extends BaseDL4JTest {
             }
         }
     }
+
+
+    @Test
+    public void testTimeDistributedDense(){
+
+        for( int rnnType=0; rnnType<3; rnnType++ ) {
+            for( int ffType=0; ffType<3; ffType++ ) {
+
+                Layer l0, l2;
+                switch (rnnType) {
+                    case 0:
+                        l0 = new LSTM.Builder().nOut(5).build();
+                        l2 = new LSTM.Builder().nOut(5).build();
+                        break;
+                    case 1:
+                        l0 = new SimpleRnn.Builder().nOut(5).build();
+                        l2 = new SimpleRnn.Builder().nOut(5).build();
+                        break;
+                    case 2:
+                        l0 = new Bidirectional(new LSTM.Builder().nOut(5).build());
+                        l2 = new Bidirectional(new LSTM.Builder().nOut(5).build());
+                        break;
+                    default:
+                        throw new RuntimeException("Not implemented: " + rnnType);
+                }
+
+                Layer l1;
+                switch (ffType){
+                    case 0:
+                        l1 = new DenseLayer.Builder().nOut(5).build();
+                        break;
+                    case 1:
+                        l1 = new VariationalAutoencoder.Builder().nOut(5).encoderLayerSizes(5).decoderLayerSizes(5).build();
+                        break;
+                    case 2:
+                        l1 = new AutoEncoder.Builder().nOut(5).build();
+                        break;
+                    default:
+                        throw new RuntimeException("Not implemented: " + ffType);
+                }
+
+                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .activation(Activation.TANH)
+                        .list()
+                        .layer(l0)
+                        .layer(l1)
+                        .layer(l2)
+                        .setInputType(InputType.recurrent(5, 9, rnnDataFormat))
+                        .build();
+
+                BaseRecurrentLayer l0a;
+                BaseRecurrentLayer l2a;
+                if (rnnType < 2) {
+                    l0a = (BaseRecurrentLayer) l0;
+                    l2a = (BaseRecurrentLayer) l2;
+                } else {
+                    l0a = (BaseRecurrentLayer) ((Bidirectional) l0).getFwd();
+                    l2a = (BaseRecurrentLayer) ((Bidirectional) l2).getFwd();
+                }
+                assertEquals(rnnDataFormat, l0a.getRnnDataFormat());
+                assertEquals(rnnDataFormat, l2a.getRnnDataFormat());
+
+                MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                net.init();
+                INDArray in = Nd4j.rand(DataType.FLOAT, rnnDataFormat == RNNFormat.NCW ? new long[]{2, 5, 9} : new long[]{2, 9, 5});
+                net.output(in);
+            }
+        }
+    }
 }
@@ -15,21 +15,24 @@
  ******************************************************************************/
 package org.deeplearning4j.convolution;

-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
+import lombok.*;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.CuDNNTestUtils;
 import org.deeplearning4j.TestUtils;
+import org.deeplearning4j.nn.api.MaskState;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
+import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
+import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
+import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.deeplearning4j.nn.workspace.ArrayType;
+import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -816,6 +819,12 @@ public class ConvDataFormatTests extends BaseDL4JTest {
                 .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build())
                 .setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format));

+        if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){
+            //Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened
+            //DL4J's flattening behaviour matches Keras (hence TF) for import compatibility
+            builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor()));
+        }
+
         MultiLayerNetwork net = new MultiLayerNetwork(builder.build());
         net.init();
         return net;
@@ -964,4 +973,35 @@ public class ConvDataFormatTests extends BaseDL4JTest {
         }
         return differs;
     }
+
+    //Converts NHWC to NCHW activations
+    @EqualsAndHashCode
+    private static class NHWCToNCHWPreprocessor implements InputPreProcessor {
+
+        @Override
+        public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2));
+        }
+
+        @Override
+        public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1));
+        }
+
+        @Override
+        public InputPreProcessor clone() {
+            return this;
+        }
+
+        @Override
+        public InputType getOutputType(InputType inputType) {
+            InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
+            return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW);
+        }
+
+        @Override
+        public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
+            return null;
+        }
+    }
 }
@@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest {
         ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false);
         model = model.convertDataType(DataType.DOUBLE);

-        INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize});
+        INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC

         CuDNNTestUtils.assertHelpersPresent(model.getLayers());
         Map<String,INDArray> withCudnn = model.feedForward(in, false);
@@ -113,6 +113,12 @@
         <scope>test</scope>
     </dependency>

+    <dependency>
+        <groupId>org.datavec</groupId>
+        <artifactId>datavec-python</artifactId>
+        <version>${datavec.version}</version>
+        <scope>test</scope>
+    </dependency>
 </dependencies>

 <profiles>
@@ -19,6 +19,9 @@ package org.deeplearning4j.nn.modelimport.keras.layers;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang3.ArrayUtils;
+import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.Convolution3D;
 import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
@@ -121,27 +124,29 @@ public class KerasInput extends KerasLayer {
         InputType myInputType;
         switch (this.inputShape.length) {
             case 1:
-                myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]);
+                myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null);
                 break;
             case 2:
                 if(this.dimOrder != null) {
+                    System.out.println("Dim order: " + this.dimOrder);
+                    System.out.println("Input shape: " + ArrayUtils.toString(this.inputShape));
                     switch (this.dimOrder) {
                         case TENSORFLOW: //NWC == channels_last
-                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]);
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                             break;
                         case THEANO: //NCW == channels_first
-                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW);
                             break;
                         case NONE:
                             //Assume RNN in [mb, seqLen, size] format
-                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                            myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                            break;
                         default:
                             throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder);
                     }
                 } else {
                     //Assume RNN in [mb, seqLen, size] format
-                    myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]);
+                    myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC);
                 }

                 break;
@@ -150,17 +155,17 @@ public class KerasInput extends KerasLayer {
                     case TENSORFLOW:
                         /* TensorFlow convolutional input: # rows, # cols, # channels */
                         myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1],
-                                this.inputShape[2]);
+                                this.inputShape[2], CNN2DFormat.NHWC);
                         break;
                     case THEANO:
                         /* Theano convolutional input: # channels, # rows, # cols */
                         myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
-                                this.inputShape[0]);
+                                this.inputShape[0], CNN2DFormat.NCHW);
                         break;
                     default:
                         this.dimOrder = DimOrder.THEANO;
                         myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2],
-                                this.inputShape[0]);
+                                this.inputShape[0], CNN2DFormat.NCHW);
                         log.warn("Couldn't determine dim ordering / data format from model file. Older Keras " +
                                 "versions may come without specified backend, in which case we assume the model was " +
                                 "built with theano." );
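Sketch of the TENSORFLOW branch above: a channels_last Keras input_shape of (64, 64, 3) maps to (sizes chosen for illustration):

InputType t = new InputType.InputTypeConvolutional(64, 64, 3, CNN2DFormat.NHWC); // rows, cols, channels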
@@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer;
 import org.deeplearning4j.nn.conf.GradientNormalization;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;

@@ -65,6 +66,9 @@ public class TFOpLayer extends Layer {
         long[] shape = inputType.getShape(true);
         TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null);
         long[] outputShape = tempLayer.getOutputShape(shape);
+        if (outputShape.length == 3){
+            return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC);
+        }
         return InputType.inferInputType(Nd4j.create(outputShape));

     }
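Sketch of the new rank-3 branch: a TF op output of shape [mb, seqLen, size] = [8, 20, 32] yields (sizes illustrative; indices follow the code above, where outputShape[2] is the size and outputShape[1] the sequence length):

InputType t = InputType.recurrent(32, 20, RNNFormat.NWC);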
@@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
     }

     private INDArray runGraph(INDArray input){
-        if (input.rank() == 3){
-            // TODO make this a preprocessor
-            input = input.permute(0, 2, 1);
-        }
         Map<String, INDArray> inputMap = new HashMap<>();
         inputMap.put(inputNames.get(0), input);
         INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0];
-        if (out.rank() == 3){
-            out = out.permute(0, 2, 1); // TODO post-processing?
-        }
-
         return out;
     }
@@ -95,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution {

        IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(),
                enforceTrainingConfig, conf, kerasMajorVersion);

        Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName)
                .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
                .activation(getIActivationFromConfig(layerConfig, conf))

@@ -104,7 +103,7 @@ public class KerasConvolution1D extends KerasConvolution {
                .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
                .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
                .hasBias(hasBias)
                .stride(getStrideFromConfig(layerConfig, 1, conf)[0]);
                .stride(getStrideFromConfig(layerConfig, 1, conf)[0]).rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW);
        int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion);
        if (hasBias)
            builder.biasInit(0.0);
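Note: the net effect of the change above is that imported TF Keras Conv1D layers now declare NWC data format. A hedged usage sketch of the same builder API (the hyperparameters are illustrative):

```java
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;

public class Conv1DFormatSketch {
    public static void main(String[] args) {
        Convolution1DLayer layer = new Convolution1DLayer.Builder()
                .nIn(4).nOut(16)
                .kernelSize(3).stride(1)
                .rnnDataFormat(RNNFormat.NWC) // expect [minibatch, seqLength, channels] input
                .build();
        System.out.println(layer);
    }
}
```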
@@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;

@@ -27,6 +28,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.weights.IWeightInit;

import java.util.Map;

@@ -93,6 +95,7 @@ public class KerasConvolution2D extends KerasConvolution {
        LayerConstraint weightConstraint = KerasConstraintUtils.getConstraintsFromConfig(
                layerConfig, conf.getLAYER_FIELD_W_CONSTRAINT(), conf, kerasMajorVersion);

        ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName)
                .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
                .activation(getIActivationFromConfig(layerConfig, conf))

@@ -101,7 +104,8 @@ public class KerasConvolution2D extends KerasConvolution {
                .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
                .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
                .hasBias(hasBias)
                .stride(getStrideFromConfig(layerConfig, 2, conf));
                .stride(getStrideFromConfig(layerConfig, 2, conf))
                .dataFormat((dimOrder == DimOrder.TENSORFLOW) ? CNN2DFormat.NHWC : CNN2DFormat.NCHW);
        int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion);
        if (hasBias)
            builder.biasInit(0.0);
@@ -360,8 +360,19 @@ public class KerasConvolutionUtils {
            }

        } else if (dimension == 1) {
            int paddingInt = (int) innerConfig.get(layerField);
            padding = new int[]{paddingInt, paddingInt};
            Object paddingObj = innerConfig.get(layerField);
            if (paddingObj instanceof List){
                List<Integer> paddingList = (List) paddingObj;
                padding = new int[]{
                        paddingList.get(0),
                        paddingList.get(1)
                };
            } else {
                int paddingInt = (int) innerConfig.get(layerField);
                padding = new int[]{paddingInt, paddingInt};
            }

        } else {
            throw new UnsupportedKerasConfigurationException(
                    "Keras padding layer not supported");
|
|
|
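Note: the padding field in Keras JSON can be a single int or a two-element list; the branch above handles both. A self-contained sketch of the same parsing idea (the config map here is fabricated for illustration):

```java
import java.util.List;
import java.util.Map;

public class PaddingParseSketch {
    @SuppressWarnings("unchecked")
    static int[] parse1DPadding(Map<String, Object> innerConfig, String layerField) {
        Object paddingObj = innerConfig.get(layerField);
        if (paddingObj instanceof List) {
            List<Integer> list = (List<Integer>) paddingObj; // e.g. "padding": [1, 2]
            return new int[]{list.get(0), list.get(1)};
        }
        int p = (int) paddingObj;                            // e.g. "padding": 1
        return new int[]{p, p};
    }

    public static void main(String[] args) {
        System.out.println(java.util.Arrays.toString(
                parse1DPadding(Map.of("padding", 3), "padding"))); // [3, 3]
    }
}
```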
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;

@@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;

import java.util.Map;
@@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer {
            switch (this.getDimOrder()) {
                case NONE:
                case THEANO:
                    preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels());
                    preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW);
                    break;
                case TENSORFLOW:
                    preprocessor = new TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(),
                            it.getChannels());
                    preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC);
                    break;
                default:
                    throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder());

@@ -111,7 +110,7 @@ public class KerasFlatten extends KerasLayer {
            // to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
            InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
            val inputShape = new long[]{it.getSize()};
            preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
            preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null);
        }
        return preprocessor;
    }
|
|
|
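Note: Flatten now routes both backends through the same preprocessor, differing only in the declared format. A hedged sketch of what that preprocessor does to shapes (dimensions invented; the 4-arg constructor is the one this commit introduces):

```java
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class FlattenSketch {
    public static void main(String[] args) {
        // NHWC activations [minibatch=2, h=3, w=3, c=4] flattened to [2, 36]
        CnnToFeedForwardPreProcessor pre = new CnnToFeedForwardPreProcessor(3, 3, 4, CNN2DFormat.NHWC);
        INDArray act = Nd4j.rand(new int[]{2, 3, 3, 4});
        INDArray out = pre.preProcess(act, 2, LayerWorkspaceMgr.noWorkspaces());
        System.out.println(java.util.Arrays.toString(out.shape())); // [2, 36]
    }
}
```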
@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core;

import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.misc.RepeatVector;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;

@@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer {
        super(layerConfig, enforceTrainingConfig);

        this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf))
                .dataFormat(RNNFormat.NWC)
                .name(this.layerName).build();
    }
|
|
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core;

import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;

@@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer {
                } else {
                    targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
                }
                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW);
            } else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
                if (inputShape[0] != targetShape[0])
                    targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
                preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC);
            }

        } else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {

@@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer {
            } else {
                targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
            }
            preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
            preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
        } else {
            if (inputShape[0] != targetShape[0])
                targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
            preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
            preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null);
        }
    } else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
        InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
        val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
        preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
        preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
    } else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
        InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
        val inputShape = new long[]{it.getSize()};
        if (targetShape.length == 3) {
            targetShape = targetShapeForDimOrder(inputShape, targetShape);
        }
        preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
        preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null);
    }
    return preprocessor;
}
|
|
|
@@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;

@@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer {
                .biasInit(0.0)
                .l1(this.weightL1Regularization)
                .l2(this.weightL2Regularization)
                .outputDataFormat(RNNFormat.NWC)
                .hasBias(false);
        if (embeddingConstraint != null)
            builder.constrainWeights(embeddingConstraint);
|
|
|
@@ -186,7 +186,7 @@ public class KerasLSTM extends KerasLayer {
                .weightInitRecurrent(recurrentInit)
                .biasInit(0.0) // TODO: this is incorrect
                .l1(this.weightL1Regularization)
                .l2(this.weightL2Regularization);
                .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
        Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
        if (nIn != null)
            builder.setNIn(nIn);
|
|
@@ -158,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer {
                .weightInitRecurrent(recurrentInit)
                .biasInit(0.0)
                .l1(this.weightL1Regularization)
                .l2(this.weightL2Regularization);
                .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC);
        Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf);
        if (nIn != null)
            builder.setNIn(nIn);
|
|
|
@@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer {
                break;
            case "SimpleRNN":
                kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers);
                SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
                Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer();
                this.layer = new Bidirectional(mode, rnnLayer);
                layer.setLayerName(layerName);
                break;
|
|
@@ -21,6 +21,9 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;

@@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
    private final long[] inputShape;
    private final long[] targetShape;
    private boolean hasMiniBatchDimension;

    /**
     * @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
     */
    @Deprecated
    public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
        this(inputShape, targetShape, false);
    }
    private DataFormat format;

    /**
     * @param inputShape            Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
     * @param targetShape           Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
     * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
     */
    public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) {
        this(inputShape, targetShape, hasMiniBatchDimension, null);
    }

    /**
     * @param inputShape            Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
     * @param targetShape           Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
     * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
     * @param dataFormat            May be null. If non-null: used to determine the data format (NCW/NWC, NCHW/NHWC) of the output type
     */
    public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
                               @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
                               @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension,
                               @JsonProperty("dataFormat") DataFormat dataFormat) {
        this.inputShape = inputShape;
        this.targetShape = targetShape;
        this.hasMiniBatchDimension = hasMiniBatchDimension;
        this.format = dataFormat;
    }

    private long[] getShape(long[] originalShape, long minibatch) {

@@ -140,13 +148,26 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
                ret = InputType.feedForward(shape[1]);
                break;
            case 3:
                ret = InputType.recurrent(shape[2], shape[1]);
                RNNFormat format = RNNFormat.NCW;
                if (this.format != null && this.format instanceof RNNFormat)
                    format = (RNNFormat) this.format;

                ret = InputType.recurrent(shape[2], shape[1], format);
                break;
            case 4:
                if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
                    ret = InputType.convolutional(shape[1], shape[2], shape[3]);
                } else {
                    ret = InputType.convolutional(shape[2], shape[3], shape[1]);

                    CNN2DFormat cnnFormat = CNN2DFormat.NCHW;
                    if (this.format != null && this.format instanceof CNN2DFormat)
                        cnnFormat = (CNN2DFormat) this.format;

                    if (cnnFormat == CNN2DFormat.NCHW) {
                        ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat);
                    } else {
                        ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat);
                    }
                }
                break;
            default:
|
|
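Note: to see the new dataFormat parameter in action, a hedged round-trip sketch (shapes chosen arbitrarily) showing how the format determines the InputType returned for rank-3 outputs:

```java
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;

public class ReshapeFormatSketch {
    public static void main(String[] args) {
        // Reshape [784] -> [28, 28], declared NWC: output type is recurrent with NWC format
        ReshapePreprocessor pre = new ReshapePreprocessor(
                new long[]{784}, new long[]{28, 28}, false, RNNFormat.NWC);
        InputType out = pre.getOutputType(InputType.feedForward(784));
        System.out.println(out); // InputTypeRecurrent with NWC ordering
    }
}
```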
@@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Specialized CnnToFeedForwardInputPreProcessor for use with
 * Convolutional layers imported from Keras using the TensorFlow
 * backend.
 *
 * @author dave@skymind.io
 * @deprecated Exists only for backward compatibility of older pretrained models. Should not be used.
 * Use {@link CnnToFeedForwardPreProcessor} for all new models instead.
 */
@Slf4j
@Slf4j @Deprecated
public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor {

    @JsonCreator
    @JsonCreator @Deprecated
    public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight,
                                                  @JsonProperty("inputWidth") long inputWidth,
                                                  @JsonProperty("numChannels") long numChannels) {
        super(inputHeight, inputWidth, numChannels);
    }

    @Deprecated
    public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
        super(inputHeight, inputWidth);
    }

    @Deprecated
    public TensorFlowCnnToFeedForwardPreProcessor() {
        super();
    }
|
|
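Note: the deprecated class is kept only so older zoo models still deserialize. A hedged before/after sketch of the equivalent construction (dimensions illustrative):

```java
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;

public class PreprocessorMigration {
    public static void main(String[] args) {
        // Deprecated: kept for deserializing older pretrained models only
        @SuppressWarnings("deprecation")
        TensorFlowCnnToFeedForwardPreProcessor old = new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 3);
        // Replacement: the general preprocessor with an explicit NHWC format
        CnnToFeedForwardPreProcessor now = new CnnToFeedForwardPreProcessor(28, 28, 3, CNN2DFormat.NHWC);
        System.out.println(old + " -> " + now);
    }
}
```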
@@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*;
import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*;
import org.deeplearning4j.nn.modelimport.keras.layers.core.*;
import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding;
import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout;
import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise;

@@ -319,6 +320,8 @@ public class KerasLayerUtils {
            layer = new KerasELU(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){
            layer = new KerasSoftmax(layerConfig, enforceTrainingConfig);
        } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){
            layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig);
        } else if (conf instanceof Keras2LayerConfiguration){
            Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration) conf;
            if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){
|
|
@@ -1,50 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.modelimport.keras;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;

import java.io.File;
import java.util.Arrays;

public class TFKerasTests extends BaseDL4JTest {

    @Test
    public void testModelWithTFOp1() throws Exception {
        File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5");
        ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
        INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
        Assert.assertArrayEquals(new long[]{12, 3}, out.shape());
    }

    @Test
    public void testModelWithTFOp2() throws Exception {
        File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5");
        ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath());
        INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3));
        // dl4j's feedforward doesn't support 3D output, so batch and time axes get squashed
        long[] expectedShape = new long[]{12 * 2, 5};
        Assert.assertArrayEquals(expectedShape, out.shape());
    }

}
|
|
@@ -0,0 +1,147 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.nn.modelimport.keras;

import org.apache.commons.io.FileUtils;
import org.datavec.python.keras.Model;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.common.tests.ResourceUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;

import java.io.File;
import java.util.List;


@RunWith(Parameterized.class)
public class TestTFKerasModelImport extends BaseDL4JTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    private String modelFile;

    @Override
    public long getTimeoutMilliseconds() {
        return 300000;
    } // installing TF will take a while


    @Parameterized.Parameters(name = "file={0}")
    public static Object[] params() throws Exception {
        List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false);
        return paths.toArray(new String[0]);
    }

    public TestTFKerasModelImport(String modelFile) {
        this.modelFile = modelFile;
    }

    @Test
    public void testModelImport() throws Exception {
        testModelImportWithData(modelFile);
    }

    private void testModelImportWithData(String path) throws Exception {
        System.out.println(path);
        // TODO multi input/output
        INDArray inputArray;
        INDArray expectedOutputArray;
        File f = Resources.asFile(path); // may be in a JAR that HDF5 can't read from
        File modelFile = new File(testDir.getRoot(), f.getName());
        FileUtils.copyFile(f, modelFile);

        synchronized (Hdf5Archive.LOCK_OBJECT) {
            Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath());
            List<String> rootGroups = hdf5Archive.getGroups();
            if (rootGroups.contains("data")) {
                String inputName = hdf5Archive.readAttributeAsString("input_names", "data");
                String outputName = hdf5Archive.readAttributeAsString("output_names", "data");
                inputArray = hdf5Archive.readDataSet(inputName, "data");
                expectedOutputArray = hdf5Archive.readDataSet(outputName, "data");
            } else {
                hdf5Archive.close();
                return;
            }
            hdf5Archive.close();
        }
        INDArray outputArray;

        ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
        outputArray = dl4jModel.outputSingle(inputArray);

        expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT);
        outputArray = outputArray.castTo(DataType.FLOAT);
        if (path.contains("misc_")) {
            // shape relaxation
            expectedOutputArray = expectedOutputArray.reshape(-1);
            outputArray = outputArray.reshape(-1);
        }

        System.out.println(outputArray.toString());
        System.out.println(expectedOutputArray.toString());
        Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape());
        Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3));
    }

    private void testModelImportWithKeras(String path) throws Exception {
        Model kerasModel = new Model(path);
        ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path);
        Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays());
        Assert.assertEquals(kerasModel.numOutputs(), dl4jModel.getNumOutputArrays());
        INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()];
        INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()];

        for (int i = 0; i < kerasInputArrays.length; i++) {
            long[] shape = kerasModel.inputShapeAt(i);
            for (int j = 0; j < shape.length; j++) {
                if (shape[j] < 0) {
                    shape[j] = 1;
                }
            }

            kerasInputArrays[i] = Nd4j.rand(shape);
        }

        INDArray[] kerasOut = kerasModel.predict(kerasInputArrays);
        INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays);

        Assert.assertEquals(kerasOut.length, dl4jOut.length);

        for (int i = 0; i < kerasOut.length; i++) {
            INDArray kerasOutArr = kerasOut[i];
            kerasOutArr = kerasOutArr.reshape(1, -1); // bit of relaxation on shape
            kerasOutArr = kerasOutArr.castTo(DataType.DOUBLE);
            Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST);
            INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1);
            System.out.println(kerasOutArr.shapeInfoToString());
            System.out.println(dl4jOutArr.shapeInfoToString());
            Assert.assertEquals(kerasOutArr, dl4jOutArr);
        }
    }
}
|
|
@@ -22,7 +22,6 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

@@ -34,8 +33,7 @@ public class JsonTest extends BaseDL4JTest {
        InputPreProcessor[] pp = new InputPreProcessor[] {
                new KerasFlattenRnnPreprocessor(10, 5),
                new PermutePreprocessor(new int[]{0, 1, 2}),
                new ReshapePreprocessor(new long[]{10, 10}, new long[]{100, 1}),
                new TensorFlowCnnToFeedForwardPreProcessor()
                new ReshapePreprocessor(new long[]{10, 10}, new long[]{100, 1}, true, null)
        };
        for (InputPreProcessor p : pp) {
|
|
@@ -29,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceTo
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.resources.Resources;

@@ -250,7 +251,7 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
                .enforceTrainingConfig(false).buildSequential().getMultiLayerConfiguration();
        MultiLayerNetwork model = new MultiLayerNetwork(config);
        model.init();
        INDArray input = Nd4j.create(50, 500, 1500);
        INDArray input = Nd4j.create(DataType.FLOAT, 50, 1500, 500); //NWC format - [minibatch, seqLength, channels]
        INDArray out = model.output(input);
        assertTrue(Arrays.equals(out.shape(), new long[]{50, 64}));
    }
|
|
@@ -87,15 +87,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    @Rule
    public final TemporaryFolder testDir = new TemporaryFolder();

    public static final BiFunction<String, INDArray, INDArray> nwc2ncwExpected = new BiFunction<String, INDArray, INDArray>() {
        @Override
        public INDArray apply(String s, INDArray array) {
            if (array.rank() == 3)
                return array.permute(0, 2, 1); //NWC to NCW
            return array;
        }
    };

    @Override
    public long getTimeoutMilliseconds() {
        return 180000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources
@@ -169,28 +160,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    public void importImdbLstmTfKeras1() throws Exception {
        String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5";
        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
    }

    @Test
    public void importImdbLstmThKeras1() throws Exception {
        String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5";
        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
    }

    @Test
    public void importImdbLstmTfKeras2() throws Exception {
        String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5";
        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
    }

    @Test
    public void importImdbLstmThKeras2() throws Exception {
        String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5";
        importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected);
        importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, null);
    }

    /**

@@ -262,7 +253,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5";
        String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" +
                "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5";
        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected);
        importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null);
    }

    /**
@@ -316,7 +307,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    @Test
    public void importAcganDiscriminator() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_discriminator_1_epochs.h5");
        INDArray input = Nd4j.create(10, 1, 28, 28);
        INDArray input = Nd4j.create(10, 28, 28, 1); //NHWC
        INDArray[] output = model.output(input);
    }

@@ -403,7 +394,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {

        // Make predictions
        int miniBatch = 32;
        INDArray input = Nd4j.ones(miniBatch, 4, 10);
        INDArray input = Nd4j.ones(miniBatch, 10, 4); //NWC format - with nIn=4, seqLength = 10
        INDArray[] out = graph.output(input);

        // Fit model
@@ -450,7 +441,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    @Test
    public void importMobileNet() throws Exception {
        ComputationGraph graph = importFunctionalModelH5Test("modelimport/keras/examples/mobilenet/alternative.hdf5");
        INDArray input = Nd4j.ones(10, 3, 299, 299);
        INDArray input = Nd4j.ones(10, 299, 299, 3);
        graph.output(input);
    }

@@ -462,7 +453,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        int[] inputShape = new int[]{299, 299, 3};
        ComputationGraph graph = importFunctionalModelH5Test(
                "modelimport/keras/examples/inception/inception_tf_keras_2.h5", inputShape, false);
        INDArray input = Nd4j.ones(10, 3, 299, 299);
        INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC
        graph.output(input);
        System.out.println(graph.summary());
    }

@@ -476,7 +467,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    public void importInception() throws Exception {
        ComputationGraph graph = importFunctionalModelH5Test(
                "modelimport/keras/examples/inception/inception_v3_complete.h5");
        INDArray input = Nd4j.ones(10, 3, 299, 299);
        INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC
        graph.output(input);
        System.out.println(graph.summary());
    }
@@ -533,14 +524,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
     * - Separate (policy and value) residual architecture
     * - Separate (policy and value) convolutional architecture
     */
    @Test
    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importSepConvPolicy() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_policy.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
        model.output(input);
    }

    @Test
    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importSepResPolicy() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_policy.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);

@@ -548,28 +539,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
    }


    @Test
    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importSepConvValue() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_value.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
        model.output(input);
    }

    @Test
    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importSepResValue() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_value.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
        model.output(input);
    }

    @Test
    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importDualRes() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_res.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
        model.output(input);
    }

    @Test
    @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last
    public void importDualConv() throws Exception {
        ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_conv.h5");
        INDArray input = Nd4j.create(32, 19, 19, 10);
@@ -634,16 +625,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        System.out.println("Starting test: " + name);
        String modelPath = "modelimport/keras/examples/causal_conv1d/" + name;
        String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0, name.length() - "model.h5".length()) + "inputs_and_outputs.h5");
        Function<INDArray, INDArray> f = new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray i) {
                //NWC to NCW
                return i.permute(0, 2, 1);
            }
        };

        MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true,
                true, true, false, f, nwc2ncwExpected);
                true, true, false, null, null);
        Layer l = net.getLayer(0);
        Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig();
        assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode());
@@ -707,25 +691,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        System.out.println("Starting test: " + name);
        String modelPath = "modelimport/keras/examples/conv1d/" + name;
        String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0, name.length() - "model.h5".length()) + "inputs_and_outputs.h5");
        Function<INDArray, INDArray> f = name.contains("_cf_") ? null : new Function<INDArray, INDArray>() {
            @Override
            public INDArray apply(INDArray i) {
                //NWC to NCW
                return i.permute(0, 2, 1);
            }
        };

        BiFunction<String, INDArray, INDArray> f2 = name.contains("_cf_") ? null : new BiFunction<String, INDArray, INDArray>() {
            @Override
            public INDArray apply(String s, INDArray array) {
                // if("conv".equals(s)){
                return array.permute(0, 2, 1);
                // }
            }
        };

        importEndModelTest(modelPath, inputsOutputPath, true, true,
                true, true, false, f, f2);
                true, true, false, null, null); //f, f2);
    }
}
@@ -882,8 +850,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        INDArray[] inputs = new INDArray[inputNames.size()];
        for (int i = 0; i < inputNames.size(); i++) {
            inputs[i] = archive.readDataSet(inputNames.get(i), GROUP_ATTR_INPUTS);
            if (inputs[i].shape().length == 4 && tensorFlowImageDimOrdering)
                inputs[i] = inputs[i].permute(0, 3, 1, 2);
        }
        return inputs;
    }

@@ -893,8 +859,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        Map<String, INDArray> activations = new HashMap<String, INDArray>();
        for (String layerName : archive.getDataSets(GROUP_ACTIVATIONS)) {
            INDArray activation = archive.readDataSet(layerName, GROUP_ACTIVATIONS);
            if (activation.shape().length == 4 && tensorFlowImageDimOrdering)
                activation = activation.permute(0, 3, 1, 2);
            activations.put(layerName, activation);
        }
        return activations;

@@ -907,8 +871,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        INDArray[] outputs = new INDArray[outputNames.size()];
        for (int i = 0; i < outputNames.size(); i++) {
            outputs[i] = archive.readDataSet(outputNames.get(i), GROUP_ATTR_OUTPUTS);
            if (outputs[i].shape().length == 4 && tensorFlowImageDimOrdering)
                outputs[i] = outputs[i].permute(0, 3, 1, 2);
        }
        return outputs;
    }

@@ -920,8 +882,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        INDArray[] predictions = new INDArray[outputNames.size()];
        for (int i = 0; i < outputNames.size(); i++) {
            predictions[i] = archive.readDataSet(outputNames.get(i), GROUP_PREDICTIONS);
            if (predictions[i].shape().length == 4 && tensorFlowImageDimOrdering)
                predictions[i] = predictions[i].permute(0, 3, 1, 2);
        }
        return predictions;
    }

@@ -941,6 +901,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest {
        // skip too small absolute inputs
        if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) {
            boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps);
            if (!eq) {
                System.out.println("Expected: " + Arrays.toString(expected.shape()) + ", actual: " + Arrays.toString(actual.shape()));
                System.out.println("Expected:\n" + expected);
                System.out.println("Actual: \n" + actual);
            }
            assertTrue("Output differs: " + label, eq);
        }
    }
|
|
@@ -176,10 +176,10 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
        INDArray bias = model.getLayer(0).getParam("b");
        assertEquals(6, bias.length());

        INDArray input = Nd4j.ones(1, 5, 3, 4);
        INDArray input = Nd4j.ones(1, 3, 4, 5); //NHWC
        INDArray output = model.output(input);

        assertArrayEquals(new long[]{1, 6, 1, 2}, output.shape());
        assertArrayEquals(new long[]{1, 1, 2, 6}, output.shape()); //NHWC

        logSuccess(modelPath);
    }

@@ -224,7 +224,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {

        INDArray input = Nd4j.zeros(mb, inputLength);
        INDArray output = model.output(input);
        assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape());
        assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC
        logSuccess(modelPath);
    }

@@ -238,9 +238,9 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
        KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);
        MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false);

        INDArray input = Nd4j.zeros(10, 4, 6, 6);
        INDArray input = Nd4j.zeros(10, 6, 6, 4);
        INDArray output = model.output(input);
        assertArrayEquals(new long[]{10, 16, 3, 3}, output.shape());
        assertArrayEquals(new long[]{10, 3, 3, 16}, output.shape());
        logSuccess(modelPath);
    }

@@ -248,10 +248,11 @@ public class KerasWeightSettingTests extends BaseDL4JTest {
        KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class);
        ComputationGraph model = loadComputationalGraph(modelPath, false);

        INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)};
        // INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)};
        INDArray input[] = new INDArray[]{Nd4j.zeros(10, 6, 6, 4), Nd4j.zeros(10, 3, 3, 16)};
        INDArray[] output = model.output(input);
        log.info(Arrays.toString(output[0].shape()));
        assertArrayEquals(new long[]{10, 32, 3, 3}, output[0].shape());
        assertArrayEquals(new long[]{10, 3, 3, 32}, output[0].shape());
        logSuccess(modelPath);
    }

@@ -278,7 +279,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {

        INDArray inEmbedding = Nd4j.zeros(mb, inputLength);
        INDArray output = model.output(inEmbedding);
        assertArrayEquals(new long[]{mb, nOut, inputLength}, output.shape());
        assertArrayEquals(new long[]{mb, inputLength, nOut}, output.shape()); //NWC format
        logSuccess(modelPath);
    }

@@ -304,7 +305,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest {

        INDArray inEmbedding = Nd4j.zeros(mb, inputLength);
        INDArray output = model.output(inEmbedding);
        assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape());
        assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC
        logSuccess(modelPath);
    }
|
|
@@ -9,7 +9,7 @@ package org.deeplearning4j.nn.conf;
 *
 * @author Alex Black
 */
public enum CNN2DFormat {
public enum CNN2DFormat implements DataFormat {
    NCHW,
    NHWC;
|
|
@@ -0,0 +1,26 @@
/* ******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
package org.deeplearning4j.nn.conf;

import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer;
import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

@JsonSerialize(using = DataFormatSerializer.class)
@JsonDeserialize(using = DataFormatDeserializer.class)
public interface DataFormat {
}
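Note: the point of the marker interface is that a single configuration field can now hold either an RNN or a CNN format and survive JSON round-trips via the custom serializer. A hedged sketch using the enums as declared in this diff:

```java
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;

public class DataFormatSketch {
    public static void main(String[] args) {
        DataFormat rnn = RNNFormat.NWC;    // both enums implement DataFormat
        DataFormat cnn = CNN2DFormat.NHWC;
        // Consumers downcast based on rank, as ReshapePreprocessor.getOutputType does:
        System.out.println(rnn instanceof RNNFormat);   // true
        System.out.println(cnn instanceof CNN2DFormat); // true
    }
}
```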
@@ -663,7 +663,7 @@ public class MultiLayerConfiguration implements Serializable, Cloneable {
                BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer;
                val nIn = brl.getNIn();
                if (nIn > 0) {
                    inputType = InputType.recurrent(nIn);
                    inputType = InputType.recurrent(nIn, brl.getRnnDataFormat());
                }
            } else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer
                    || firstLayer instanceof OutputLayer) {
|
|
@@ -23,7 +23,7 @@ package org.deeplearning4j.nn.conf;
 * "width" corresponds to sequence length and "channels" corresponds to sequence item size.
 */

public enum RNNFormat {
public enum RNNFormat implements DataFormat {
    NCW,
    NWC
}
@@ -18,6 +18,8 @@ package org.deeplearning4j.nn.conf.graph;

import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.Convolution3D;

@@ -38,6 +40,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 */
public class MergeVertex extends GraphVertex {

    protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format

    @Override
    public MergeVertex clone() {
        return new MergeVertex();

@@ -76,7 +80,7 @@ public class MergeVertex extends GraphVertex {
    @Override
    public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx,
                    INDArray paramsView, boolean initializeParams, DataType networkDatatype) {
        return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype);
        return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype, mergeAxis);
    }

    @Override

@@ -126,6 +130,7 @@ public class MergeVertex extends GraphVertex {
            //FF or RNN data inputs
            int size = 0;
            InputType.Type type = null;
            RNNFormat format = null;
            for (int i = 0; i < vertexInputs.length; i++) {
                if (vertexInputs[i].getType() != first.getType()) {
                    throw new InvalidInputTypeException(

@@ -142,6 +147,8 @@ public class MergeVertex extends GraphVertex {
                        break;
                    case RNN:
                        thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize();
                        format = ((InputType.InputTypeRecurrent) vertexInputs[i]).getFormat();
                        this.mergeAxis = format == RNNFormat.NCW ? 1 : 2;
                        type = InputType.Type.RNN;
                        break;
                    default:

@@ -160,7 +167,7 @@ public class MergeVertex extends GraphVertex {
                    return InputType.feedForward(size);
                } else {
                    val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength();
                    return InputType.recurrent(size, tsLength);
                    return InputType.recurrent(size, tsLength, format);
                }
            } else {
                //size is unknown

@@ -168,13 +175,14 @@ public class MergeVertex extends GraphVertex {
                    return InputType.feedForward(-1);
                } else {
                    val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength();
                    return InputType.recurrent(-1, tsLength);
                    return InputType.recurrent(-1, tsLength, format);
                }
            }

        } else {
            //CNN inputs... also check that the channels, width and heights match:
            InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first;
            CNN2DFormat format = firstConv.getFormat();

            val fd = firstConv.getChannels();
            val fw = firstConv.getWidth();

@@ -206,7 +214,8 @@ public class MergeVertex extends GraphVertex {
                depthSum += od;
            }

            return InputType.convolutional(fh, fw, depthSum);
            this.mergeAxis = format == CNN2DFormat.NCHW ? 1 : 3;
            return InputType.convolutional(fh, fw, depthSum, format);
        }
    }
|
|
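Note: the merge axis matters because concatenating NWC activations must happen along the channel axis (2), not the time axis. A hedged ND4J illustration of the axis choice (shapes invented):

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MergeAxisSketch {
    public static void main(String[] args) {
        INDArray a = Nd4j.rand(new int[]{8, 20, 16}); // NWC: [minibatch, seqLen, channels=16]
        INDArray b = Nd4j.rand(new int[]{8, 20, 24}); // NWC: [minibatch, seqLen, channels=24]
        INDArray merged = Nd4j.concat(2, a, b);       // NWC merges on axis 2
        System.out.println(java.util.Arrays.toString(merged.shape())); // [8, 20, 40]
        // For NCW the same merge would use axis 1; for NHWC conv activations, axis 3.
    }
}
```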
@@ -20,6 +20,7 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.layers.Convolution3D;

@@ -91,7 +92,11 @@ public abstract class InputType implements Serializable {
     * @return InputTypeFeedForward
     */
    public static InputType feedForward(long size) {
        return new InputTypeFeedForward(size);
        return new InputTypeFeedForward(size, null);
    }

    public static InputType feedForward(long size, DataFormat timeDistributedFormat) {
        return new InputTypeFeedForward(size, timeDistributedFormat);
    }

    /**

@@ -132,7 +137,6 @@ public abstract class InputType implements Serializable {
     * @return InputTypeConvolutional
     */
    public static InputType convolutional(long height, long width, long depth) {
        // return new InputTypeConvolutional(height, width, depth);
        return convolutional(height, width, depth, CNN2DFormat.NCHW);
    }

@@ -191,9 +195,11 @@ public abstract class InputType implements Serializable {
    @EqualsAndHashCode(callSuper = false)
    public static class InputTypeFeedForward extends InputType {
        private long size;
        private DataFormat timeDistributedFormat;

        public InputTypeFeedForward(@JsonProperty("size") long size) {
        public InputTypeFeedForward(@JsonProperty("size") long size, @JsonProperty("timeDistributedFormat") DataFormat timeDistributedFormat) {
            this.size = size;
            this.timeDistributedFormat = timeDistributedFormat;
        }

        @Override

@@ -203,7 +209,7 @@ public abstract class InputType implements Serializable {

        @Override
        public String toString() {
            return "InputTypeFeedForward(" + size + ")";
            return "InputTypeFeedForward(" + size + (timeDistributedFormat != null ? "," + timeDistributedFormat : "") + ")";
        }

        @Override

@@ -302,7 +308,8 @@ public abstract class InputType implements Serializable {
            this.height = height;
            this.width = width;
            this.channels = channels;
            this.format = format;
            if (format != null)
                this.format = format;
        }

        public InputTypeConvolutional(long height, long width, long channels) {
|
|
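Note: the new timeDistributedFormat field lets a feed-forward InputType remember the sequence layout it came from, so a later FF-to-RNN preprocessor can restore it. A hedged sketch:

```java
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;

public class TimeDistributedSketch {
    public static void main(String[] args) {
        // Plain FF input type: no sequence format attached
        InputType plain = InputType.feedForward(128);
        // FF input type that remembers it was unrolled from an NWC sequence
        InputType td = InputType.feedForward(128, RNNFormat.NWC);
        System.out.println(plain); // InputTypeFeedForward(128)
        System.out.println(td);    // InputTypeFeedForward(128,NWC)
    }
}
```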
@@ -64,11 +64,11 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
                    + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }

        InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
        if (nIn <= 0 || override) {
            InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
            this.nIn = r.getSize();
            this.rnnDataFormat = r.getFormat();
        }
        this.rnnDataFormat = r.getFormat();
    }

    @Override
|
|
@@ -44,6 +44,7 @@ import java.util.Map;
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class Convolution1DLayer extends ConvolutionLayer {
    private RNNFormat rnnDataFormat = RNNFormat.NCW;
    /*
    //TODO: We will eventually want to NOT subclass off of ConvolutionLayer.
    //Currently, we just subclass off the ConvolutionLayer and hard code the "width" dimension to 1

@@ -56,6 +57,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
    private Convolution1DLayer(Builder builder) {
        super(builder);
        initializeConstraints(builder);
        this.rnnDataFormat = builder.rnnDataFormat;
    }

    @Override

@@ -92,7 +94,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
            outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
                    convolutionMode, dilation[0]);
        }
        return InputType.recurrent(nOut, outLength);

        return InputType.recurrent(nOut, outLength, rnnDataFormat);
    }

    @Override

@@ -102,10 +105,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
                    + "\"): expect RNN input type with size > 0. Got: " + inputType);
        }

        InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
        if (nIn <= 0 || override) {
            InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
            this.nIn = r.getSize();
        }
        this.rnnDataFormat = r.getFormat();
    }

    @Override

@@ -115,11 +119,13 @@ public class Convolution1DLayer extends ConvolutionLayer {
                    + "\"): input is null");
        }

        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName());
        return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat, getLayerName());
    }

    public static class Builder extends ConvolutionLayer.BaseConvBuilder<Builder> {

        private RNNFormat rnnDataFormat = RNNFormat.NCW;

        public Builder() {
            this(0, 1, 0);
            this.setKernelSize((int[]) null);

@@ -130,6 +136,11 @@ public class Convolution1DLayer extends ConvolutionLayer {
            return true;
        }


        public Builder rnnDataFormat(RNNFormat rnnDataFormat){
            this.rnnDataFormat = rnnDataFormat;
            return this;
        }
        /**
         * @param kernelSize Kernel size
         * @param stride     Stride
|
|
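Note: one consequence of storing rnnDataFormat on the layer is that getOutputType now reports a recurrent type in the same layout the layer consumes. A hedged sketch of the resulting shape semantics (numbers invented):

```java
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;

public class Conv1DOutputTypeSketch {
    public static void main(String[] args) {
        // nOut=16 channels, outLength=98 time steps, declared NWC:
        InputType out = InputType.recurrent(16, 98, RNNFormat.NWC);
        // Activations will be [minibatch, 98, 16]; under NCW they would be [minibatch, 16, 98].
        System.out.println(out);
    }
}
```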
@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.Layer;
|
|||
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
||||
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||
|
@@ -58,12 +59,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
     private int inputLength = 1; // By default only use one index to embed
     private boolean hasBias = false;
     private boolean inferInputLength = false; // use input length as provided by input data
+    private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models

     private EmbeddingSequenceLayer(Builder builder) {
         super(builder);
         this.hasBias = builder.hasBias;
         this.inputLength = builder.inputLength;
         this.inferInputLength = builder.inferInputLength;
+        this.outputFormat = builder.outputFormat;
         initializeConstraints(builder);
     }

@@ -87,7 +90,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
             throw new IllegalStateException("Invalid input for Embedding layer (layer index = " + layerIndex
                     + ", layer name = \"" + getLayerName() + "\"): expect FF/RNN input type. Got: " + inputType);
         }
-        return InputType.recurrent(nOut, inputLength);
+        return InputType.recurrent(nOut, inputLength, outputFormat);
     }

     @Override
@@ -167,6 +170,13 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
         */
        private boolean inferInputLength = true;

+        private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models
+
+        public Builder outputDataFormat(RNNFormat format){
+            this.outputFormat = format;
+            return this;
+        }
+
        /**
         * If true: include bias parameters in the layer. False (default): no bias.
         *

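Likewise, a hedged sketch of the new outputDataFormat option, which lets the embedding emit [minibatch, seqLength, nOut] (NWC) directly, matching TF/Keras Embedding output, instead of permuting to the NCW default; the sizes are assumptions:

    EmbeddingSequenceLayer embedding = new EmbeddingSequenceLayer.Builder()
            .nIn(10000)                      // assumed vocabulary size
            .nOut(128)                       // assumed embedding dimension
            .inputLength(50)                 // assumed sequence length
            .outputDataFormat(RNNFormat.NWC) // new: emit NWC without a permute
            .build();
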
@@ -17,6 +17,7 @@
 package org.deeplearning4j.nn.conf.layers;

 import lombok.*;
+import org.deeplearning4j.nn.conf.DataFormat;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
@@ -35,6 +36,7 @@ public abstract class FeedForwardLayer extends BaseLayer {

     protected long nIn;
     protected long nOut;
+    protected DataFormat timeDistributedFormat;

     public FeedForwardLayer(Builder builder) {
         super(builder);
@@ -51,7 +53,7 @@ public abstract class FeedForwardLayer extends BaseLayer {
                     + getLayerName() + "\"): expected FeedForward input type. Got: " + inputType);
         }

-        return InputType.feedForward(nOut);
+        return InputType.feedForward(nOut, timeDistributedFormat);
     }

     @Override
@@ -71,6 +73,11 @@ public abstract class FeedForwardLayer extends BaseLayer {
                 this.nIn = f.getFlattenedSize();
             }
         }
+
+        if(inputType instanceof InputType.InputTypeFeedForward){
+            InputType.InputTypeFeedForward f = (InputType.InputTypeFeedForward) inputType;
+            this.timeDistributedFormat = f.getTimeDistributedFormat();
+        }
     }

     @Override
@@ -536,11 +536,17 @@ public class InputTypeUtil {
         }

         switch (inputType.getType()) {
-            case FF:
             case CNNFlat:
                 //FF -> RNN or CNNFlat -> RNN
                 //In either case, input data format is a row vector per example
                 return new FeedForwardToRnnPreProcessor(rnnDataFormat);
+            case FF:
+                //If time distributed format is defined, use that. Otherwise use the layer-defined rnnDataFormat, which may be default
+                InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward)inputType;
+                if(ff.getTimeDistributedFormat() != null && ff.getTimeDistributedFormat() instanceof RNNFormat){
+                    return new FeedForwardToRnnPreProcessor((RNNFormat) ff.getTimeDistributedFormat());
+                }
+                return new FeedForwardToRnnPreProcessor(rnnDataFormat);
             case RNN:
                 //RNN -> RNN: No preprocessor necessary
                 return null;
@@ -98,9 +98,9 @@ public class RnnOutputLayer extends BaseOutputLayer {
                     + "\"): Expected RNN input, got " + inputType);
         }

+        InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
+        this.rnnDataFormat = r.getFormat();
         if (nIn <= 0 || override) {
-            InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
-            this.rnnDataFormat = r.getFormat();
             this.nIn = r.getSize();
         }
     }
@@ -91,7 +91,7 @@ public class Subsampling1DLayer extends SubsamplingLayer {
             outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0],
                     convolutionMode, dilation[0]);
         }
-        return InputType.recurrent(r.getSize(), outLength);
+        return InputType.recurrent(r.getSize(), outLength, r.getFormat());
     }

     @Override
@@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers.misc;

 import lombok.*;
 import org.deeplearning4j.nn.api.ParamInitializer;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
@@ -46,10 +47,12 @@ import java.util.Map;
 public class RepeatVector extends FeedForwardLayer {

     private int n = 1;
+    private RNNFormat dataFormat = RNNFormat.NCW;

     protected RepeatVector(Builder builder) {
         super(builder);
         this.n = builder.n;
+        this.dataFormat = builder.dataFormat;
     }

     @Override
@@ -83,7 +86,7 @@ public class RepeatVector extends FeedForwardLayer {
                     + "\"): Expected FF input, got " + inputType);
         }
         InputType.InputTypeFeedForward ffInput = (InputType.InputTypeFeedForward) inputType;
-        return InputType.recurrent(ffInput.getSize(), n);
+        return InputType.recurrent(ffInput.getSize(), n, this.dataFormat);
     }

     @Override
@@ -101,13 +104,14 @@ public class RepeatVector extends FeedForwardLayer {
     }


     @NoArgsConstructor
     @Getter
     @Setter
     public static class Builder<T extends Builder<T>> extends FeedForwardLayer.Builder<T> {

         private int n = 1; // no repetition by default
+        private RNNFormat dataFormat = RNNFormat.NCW;
         /**
          * Set repetition factor for RepeatVector layer
          */
@@ -115,6 +119,15 @@ public class RepeatVector extends FeedForwardLayer {
             return n;
         }

+        public RNNFormat getDataFormat(){
+            return dataFormat;
+        }
+
+        public Builder dataFormat(RNNFormat dataFormat){
+            this.dataFormat = dataFormat;
+            return this;
+        }
+
        /**
         * Set repetition factor for RepeatVector layer
         *

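A hedged usage sketch: with the new dataFormat option set to NWC, a [minibatch, size] input is repeated into [minibatch, n, size] instead of the NCW default [minibatch, size, n]. The repetition-factor setter name below is an assumption based on the surrounding builder javadoc:

    RepeatVector repeat = new RepeatVector.Builder()
            .repetitionFactor(5)            // assumed setter for n
            .dataFormat(RNNFormat.NWC)      // new: output [minibatch, n, size]
            .build();
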
@@ -39,11 +39,13 @@ import java.util.Arrays;
 * For example, CNN -> Denselayer <br>
 * This does two things:<br>
 * (b) Reshapes 4d activations out of CNN layer, with shape
- * [numExamples, numChannels, inputHeight, inputWidth]) into 2d activations (with shape
- * [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer
+ * [numExamples, numChannels, inputHeight, inputWidth]) (for {@link CNN2DFormat#NCHW} format activations) or shape
+ * [numExamples, inputHeight, inputWidth, numChannels] (for {@link CNN2DFormat#NHWC} format activations) into 2d activations
+ * (with shape [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer.
 * (a) Reshapes epsilons (weights*deltas) out of FeedForward layer (which is 2D or 3D with shape
 * [numExamples, inputHeight*inputWidth*numChannels]) into 4d epsilons (with shape
- * [numExamples, numChannels, inputHeight, inputWidth]) suitable to feed into CNN layers.<br>
+ * [numExamples, numChannels, inputHeight, inputWidth] or [numExamples, inputHeight, inputWidth, numChannels]) suitable to
+ * feed into CNN layers.<br>
 * Note: numChannels is equivalent to channels or featureMaps referenced in different literature
 * @author Adam Gibson
 * @see FeedForwardToCnnPreProcessor for opposite case (i.e., DenseLayer -> CNNetc)
@@ -68,7 +70,8 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
         this.inputHeight = inputHeight;
         this.inputWidth = inputWidth;
         this.numChannels = numChannels;
-        this.format = format;
+        if(format != null)
+            this.format = format;
     }

     public CnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) {
@@ -96,10 +99,17 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
             wDim = 2;
         }

+        if(inputHeight == 0 && inputWidth == 0 && numChannels == 0){
+            this.inputHeight = input.size(hDim);
+            this.inputWidth = input.size(wDim);
+            this.numChannels = input.size(chDim);
+        }
+
         if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || input.size(wDim) != inputWidth){
-            throw new IllegalStateException("Invalid input, does not match configuration: expected [minibatch, numChannels="
-                    + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] but got input array of" +
-                    "shape " + Arrays.toString(input.shape()));
+            throw new IllegalStateException("Invalid input, does not match configuration: expected " +
+                    (format == CNN2DFormat.NCHW ? "[minibatch, numChannels=" + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] " :
+                            "[minibatch, inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + ", numChannels=" + numChannels + "]") +
+                    " but got input array of shape " + Arrays.toString(input.shape()));
         }

         //Check input: nchw format
@@ -110,15 +120,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
                     + Arrays.toString(input.shape()));
         }

-        if(format == CNN2DFormat.NHWC) {
-            input = input.permute(0, 3, 1, 2); //NHWC to NCHW
-        }
-
         //Assume input is standard rank 4 activations out of CNN layer
         //First: we require input to be in c order. But c order (as declared in array order) isn't enough; also need strides to be correct
         if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input))
             input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');

+        //Note that to match Tensorflow/Keras, we do a simple "c order reshape" for both NCHW and NHWC
+
         val inShape = input.shape(); //[miniBatch,depthOut,outH,outW]
         val outShape = new long[]{inShape[0], inShape[1] * inShape[2] * inShape[3]};
@@ -139,11 +147,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor {
                     + inputHeight + " x columns " + inputWidth + " x channels " + numChannels + " but was instead "
                     + Arrays.toString(epsilons.shape()));

-        INDArray ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
-
-        if(format == CNN2DFormat.NHWC){
-            ret = ret.permute(0,2,3,1); //NCHW to NHWC
+        INDArray ret;
+        if(format == CNN2DFormat.NCHW){
+            ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth);
+        } else {
+            ret = epsilons.reshape('c', epsilons.size(0), inputHeight, inputWidth, numChannels);
         }

         return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace
     }

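The "c order reshape" note above is the crux of the Keras alignment: flattening NHWC activations in row-major order reproduces exactly what Keras' Flatten does on channels-last data, which is why the old NHWC-to-NCHW permute was removed. A small illustrative sketch, with assumed shapes:

    INDArray nhwc = Nd4j.linspace(1, 72, 72).reshape('c', 2, 3, 3, 4); // [mb=2, h=3, w=3, c=4]
    INDArray flat = nhwc.reshape('c', 2, 36);                          // [2, h*w*c], channel index varies fastest
    // Element order is h-major, then w, then c - the same order Keras' Flatten emits for NHWC input.
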
@@ -52,7 +52,8 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor {
     private RNNFormat rnnDataFormat = RNNFormat.NCW;

     public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
-        this.rnnDataFormat = rnnDataFormat;
+        if(rnnDataFormat != null)
+            this.rnnDataFormat = rnnDataFormat;
     }
     @Override
     public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
@@ -57,7 +57,8 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
     private RNNFormat rnnDataFormat = RNNFormat.NCW;

     public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){
-        this.rnnDataFormat = rnnDataFormat;
+        if(rnnDataFormat != null)
+            this.rnnDataFormat = rnnDataFormat;
     }
     @Override
     public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
@@ -116,7 +117,7 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor {
         }

         InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType;
-        return InputType.feedForward(rnn.getSize());
+        return InputType.feedForward(rnn.getSize(), rnn.getFormat());
     }

     @Override
@@ -0,0 +1,52 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.nn.conf.serde.format;
+
+import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.DataFormat;
+import org.deeplearning4j.nn.conf.RNNFormat;
+import org.nd4j.shade.jackson.core.JsonParser;
+import org.nd4j.shade.jackson.core.JsonProcessingException;
+import org.nd4j.shade.jackson.databind.DeserializationContext;
+import org.nd4j.shade.jackson.databind.JsonDeserializer;
+import org.nd4j.shade.jackson.databind.JsonNode;
+
+import java.io.IOException;
+
+/**
+ * Simple JSON deserializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat}
+ *
+ * @author Alex Black
+ */
+public class DataFormatDeserializer extends JsonDeserializer<DataFormat> {
+    @Override
+    public DataFormat deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
+        JsonNode node = jp.getCodec().readTree(jp);
+        String text = node.textValue();
+        switch (text){
+            case "NCHW":
+                return CNN2DFormat.NCHW;
+            case "NHWC":
+                return CNN2DFormat.NHWC;
+            case "NCW":
+                return RNNFormat.NCW;
+            case "NWC":
+                return RNNFormat.NWC;
+            default:
+                return null;
+        }
+    }
+}
@@ -0,0 +1,37 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.nn.conf.serde.format;
+
+import org.deeplearning4j.nn.conf.CNN2DFormat;
+import org.deeplearning4j.nn.conf.DataFormat;
+import org.deeplearning4j.nn.conf.RNNFormat;
+import org.nd4j.shade.jackson.core.JsonGenerator;
+import org.nd4j.shade.jackson.databind.JsonSerializer;
+import org.nd4j.shade.jackson.databind.SerializerProvider;
+
+import java.io.IOException;
+
+/**
+ * Simple JSON serializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat}
+ *
+ * @author Alex Black
+ */
+public class DataFormatSerializer extends JsonSerializer<DataFormat> {
+    @Override
+    public void serialize(DataFormat dataFormat, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException {
+        jsonGenerator.writeString(dataFormat.toString());
+    }
+}

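A hedged sketch of the round trip this pair enables, assuming both classes are registered on a Jackson mapper (DL4J's configuration mappers are expected to handle the registration; the setup below is illustrative only and uses the same shaded Jackson types as the files above):

    ObjectMapper mapper = new ObjectMapper();
    SimpleModule module = new SimpleModule();
    module.addSerializer(DataFormat.class, new DataFormatSerializer());
    module.addDeserializer(DataFormat.class, new DataFormatDeserializer());
    mapper.registerModule(module);

    String json = mapper.writeValueAsString(RNNFormat.NWC);          // -> "NWC"
    DataFormat restored = mapper.readValue(json, DataFormat.class);  // -> RNNFormat.NWC
    // Strings the deserializer does not recognize come back as null,
    // which callers treat as "fall back to the default format".
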
@@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.indexing.INDArrayIndex;
 import org.nd4j.linalg.indexing.NDArrayIndex;
 import org.nd4j.linalg.primitives.Pair;
 import org.deeplearning4j.nn.workspace.ArrayType;
@@ -48,14 +49,16 @@ public class MergeVertex extends BaseGraphVertex {

     private long[][] forwardPassShapes;
     private int fwdPassRank;
+    private int mergeAxis;

-    public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType) {
-        this(graph, name, vertexIndex, null, null, dataType);
+    public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType, int mergeAxis) {
+        this(graph, name, vertexIndex, null, null, dataType, mergeAxis);
     }

     public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
-                    VertexIndices[] outputVertices, DataType dataType) {
+                    VertexIndices[] outputVertices, DataType dataType, int mergeAxis) {
         super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
+        this.mergeAxis = mergeAxis;
     }

     @Override
@@ -92,7 +95,6 @@ public class MergeVertex extends BaseGraphVertex {

         forwardPassShapes = new long[in.length][0];
         val nExamples = in[0].size(0);
-        int nOut = 0;
         fwdPassRank = in[0].rank();
         for (int i = 0; i < in.length; i++) {
             val currShape = in[i].shape();
@@ -109,12 +111,11 @@ public class MergeVertex extends BaseGraphVertex {
                             + Arrays.toString(in[0].shape()) + ", activations[" + i
                             + "] shape: " + Arrays.toString(in[i].shape()));
             }
-
-            nOut += currShape[1]; //Same dimension for all of CNNs, FF, RNNs
         }

         try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
-            return Nd4j.concat(1, in);
+            INDArray out = Nd4j.concat(mergeAxis, in);
+            return out;
         }
     }

@@ -145,20 +146,16 @@ public class MergeVertex extends BaseGraphVertex {
                 break;
             case 3:
                 for (int i = 0; i < forwardPassShapes.length; i++) {
-                    out[i].assign(epsilon.get(NDArrayIndex.all(), //All rows
-                                    NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //subset of columns
-                                    NDArrayIndex.all())); //All time steps
+                    out[i].assign(epsilon.get(indices(3, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //All time steps

-                    cumulative += forwardPassShapes[i][1];
+                    cumulative += forwardPassShapes[i][mergeAxis];
                 }
                 break;
             case 4:
                 for (int i = 0; i < forwardPassShapes.length; i++) {
-                    out[i].assign(epsilon.get(NDArrayIndex.all(),
-                                    NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //Subset of depth
-                                    NDArrayIndex.all(), //Width
-                                    NDArrayIndex.all())); //height
-                    cumulative += forwardPassShapes[i][1];
+                    out[i].assign(epsilon.get(indices(4, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //height
+
+                    cumulative += forwardPassShapes[i][mergeAxis];
                 }
                 break;
             default:
@@ -168,6 +165,19 @@ public class MergeVertex extends BaseGraphVertex {
         return new Pair<>(null, out);
     }

+    private INDArrayIndex[] indices(int num, int axis, long from, long to){
+        INDArrayIndex[] out = new INDArrayIndex[num];
+        for( int i=0; i<num; i++ ){
+            if(i == axis){
+                out[i] = NDArrayIndex.interval(from, to);
+            } else {
+                out[i] = NDArrayIndex.all();
+            }
+        }
+        return out;
+    }
+
+
     @Override
     public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
         if (backpropGradientsViewArray != null)

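The new indices(...) helper is what generalizes the old hard-coded axis-1 slicing: it places interval() on the merge axis and all() everywhere else. For rank-3 NWC activations merged on axis 2, with an assumed 64-channel slice, it builds the equivalent of:

    // indices(3, 2, 0, 64) produces:
    INDArrayIndex[] idx = {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 64)};
    // i.e., every example and every time step, but only this input's channel slice of the epsilon.
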
@@ -79,7 +79,8 @@ public abstract class BaseOutputLayer<LayerConfT extends org.deeplearning4j.nn.c

         ILossFunction lossFunction = layerConf().getLossFn();

-        double score = lossFunction.computeScore(getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM), preOut,
+        INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.FF_WORKING_MEM);
+        double score = lossFunction.computeScore(labels2d, preOut,
                 layerConf().getActivationFn(), maskArray,false);

         if(conf().isMiniBatch())
@@ -21,6 +21,7 @@ import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.CacheMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.gradient.DefaultGradient;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D;
@@ -71,7 +72,11 @@ public class RepeatVector extends AbstractLayer<org.deeplearning4j.nn.conf.layer

         INDArray outEpsilon;
         try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATION_GRAD)){
-            outEpsilon = epsilon.sum(2);
+            if (layerConf().getDataFormat() == RNNFormat.NCW) {
+                outEpsilon = epsilon.sum(2);
+            } else {
+                outEpsilon = epsilon.sum(1);
+            }
         }

         Gradient gradient = new DefaultGradient();
@@ -99,10 +104,22 @@ public class RepeatVector extends AbstractLayer<org.deeplearning4j.nn.conf.layer

         long miniBatch = input.size(0);
         long size = input.size(1);
-        INDArray output = input.reshape(miniBatch, size, 1).castTo(dataType);
-        try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
-            return output.repeat(2, (long) getN());
+        if (getDataFormat() == RNNFormat.NCW) {
+            INDArray output = input.reshape(miniBatch, size, 1).castTo(dataType);
+            try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
+                return output.repeat(2, (long) getN());
+            }
+        } else {
+            INDArray output = input.reshape(miniBatch, 1, size).castTo(dataType);
+            try (MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)) {
+                return output.repeat(1, (long) getN());
+            }
+        }
     }

+    public RNNFormat getDataFormat(){
+        return layerConf().getDataFormat();
+    }
+
     @Override
@@ -20,6 +20,8 @@ import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nn.api.MaskState;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
+import org.deeplearning4j.nn.conf.layers.Convolution1D;
 import org.deeplearning4j.nn.gradient.DefaultGradient;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
@@ -74,6 +76,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
                             + Arrays.toString(epsilon.shape())
                             + ". Expected rank 3 array with shape [minibatchSize, features, length]. " + layerId());

+        if (getRnnDataFormat() == RNNFormat.NWC){
+            epsilon = epsilon.permute(0, 2, 1);
+            this.input = input.permute(0, 2, 1);
+        }
         if(maskArray != null){
             INDArray maskOut = feedForwardMaskArray(maskArray, MaskState.Active, (int)epsilon.size(0)).getFirst();
             Preconditions.checkState(epsilon.size(0) == maskOut.size(0) && epsilon.size(2) == maskOut.size(1),
@@ -125,6 +131,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
             retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY));
         }
         retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c');
+        if (getRnnDataFormat() == RNNFormat.NWC){
+            epsOut = epsOut.permute(0, 2, 1);
+            this.input = input.permute(0, 2, 1);
+        }
         return new Pair<>(retGradient, epsOut);
     }

@@ -140,7 +150,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
         // remove singleton fourth dimension from input and current epsilon
         epsNext = epsNext.reshape(epsNext.size(0), epsNext.size(1), epsNext.size(2));
         input = origInput;
-
+        if (getRnnDataFormat() == RNNFormat.NWC){
+            epsNext = epsNext.permute(0, 2, 1);
+            this.input = input.permute(0, 2, 1);
+        }
         return new Pair<>(gradientEpsNext.getFirst(), epsNext);
     }

@@ -185,7 +198,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
                 .s(c.getStride()[0])
                 .d(c.getDilation()[0])
                 .p(c.getPadding()[0])
-                .dataFormat(Conv1DConfig.NCW)
+                .dataFormat((((org.deeplearning4j.nn.conf.layers.Convolution1DLayer)
+                        layerConf()).getRnnDataFormat() == RNNFormat.NCW) ? Conv1DConfig.NCW : Conv1DConfig.NWC)
                 .paddingMode(PaddingMode.CAUSAL)
                 .build();
         INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY);
@@ -209,6 +223,9 @@ public class Convolution1DLayer extends ConvolutionLayer {

     @Override
     public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){
+        if (getRnnDataFormat() == RNNFormat.NWC){
+            this.input = input.permute(0, 2, 1);
+        }
         INDArray act4d = super.activate(training, workspaceMgr);
         INDArray act3d = act4d.reshape(act4d.size(0), act4d.size(1), act4d.size(2));

@@ -219,6 +236,10 @@ public class Convolution1DLayer extends ConvolutionLayer {
                     act3d.shape(), maskOut.shape());
             Broadcast.mul(act3d, maskOut, act3d, 0, 2);
         }
+        if (getRnnDataFormat() == RNNFormat.NWC){
+            this.input = input.permute(0, 2, 1);
+            act3d = act3d.permute(0, 2, 1);
+        }

         return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, act3d); //Should be zero copy most of the time
     }
@@ -231,4 +252,8 @@ public class Convolution1DLayer extends ConvolutionLayer {
                         layerConf().getConvolutionMode());
         return new Pair<>(reduced, currentMaskState);
     }
+
+    private RNNFormat getRnnDataFormat(){
+        return ((org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf()).getRnnDataFormat();
+    }
 }

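The NWC handling in this layer is implemented as boundary permutes around the existing NCW code path rather than a rewrite of the convolution logic. Schematically, with assumed sizes:

    INDArray nwc = Nd4j.zeros(8, 50, 32);      // assumed NWC activations [minibatch, seqLength, channels]
    INDArray ncw = nwc.permute(0, 2, 1);       // [8, 32, 50] - the internal NCW layout the layer computes in
    INDArray back = ncw.permute(0, 2, 1);      // permuted back to NWC on the way out
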
@@ -160,7 +160,8 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer

         Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr);
         INDArray z = p.getFirst();
-        if(layerConf().getCnn2dDataFormat() != CNN2DFormat.NCHW){
+        CNN2DFormat f = layerConf().getCnn2dDataFormat();
+        if(f != CNN2DFormat.NCHW){
             z = z.permute(0,3,1,2); //NHWC to NCHW
         }
         delta = afn.backprop(z, epsilon).getFirst(); //TODO handle activation function params
@@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.gradient.DefaultGradient;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.BaseLayer;
@@ -64,8 +65,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
         INDArray z = preOutput(true, workspaceMgr);
         INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); //Shape: [mb, vector, seqLength]

+        boolean ncw = layerConf().getOutputFormat() == RNNFormat.NCW;
+
         if (maskArray != null) {
-            delta = Broadcast.mul(delta, maskArray, delta, 0, 2);
+            if(ncw){
+                delta = Broadcast.mul(delta, maskArray, delta, 0, 2);
+            } else {
+                delta = Broadcast.mul(delta, maskArray, delta, 0, 1);
+            }
         }

         int inputLength = layerConf().getInputLength();
@@ -76,7 +83,10 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
             delta = delta.dup('c');
         }

-        delta = delta.permute(0, 2, 1); //From [minibatch, nOut, length] to [minibatch, length, nOut]
+        if(ncw){
+            delta = delta.permute(0, 2, 1); //From [minibatch, nOut, length] to [minibatch, length, nOut]
+        }

         delta = delta.reshape('c',inputLength * numSamples, nOut);

         INDArray weightGradients = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY);
@@ -159,7 +169,10 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
         }

         val shape = new long[]{minibatch, inputLength, nOut};
-        INDArray ret = rows.reshape('c', shape).permute(0, 2, 1);
+        INDArray ret = rows.reshape('c', shape);
+        if(layerConf().getOutputFormat() == RNNFormat.NCW){
+            ret = ret.permute(0, 2, 1); //[minibatch, seqLen, nOut] -> [minibatch, nOut, seqLen] i.e., NWC -> NCW
+        }
         return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret);
     }

@@ -177,8 +190,14 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
                     " 2 (when input is rank 3, shape [mb,1,tsLength]). Input shape: " + Arrays.toString(input.shape()) +
                     ", mask shape: " + Arrays.toString(maskArray.shape()));
         }
-            //Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
-            Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
+            boolean ncw = layerConf().getOutputFormat() == RNNFormat.NCW;
+            if(ncw){
+                //Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
+                Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
+            } else {
+                //Returned array: rank 3, shape [mb, seqLength, vector]. mask shape: [mb, seqLength]
+                Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 1);
+            }
         }
         return ret;
     }

@@ -20,11 +20,13 @@ import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.api.MaskState;
 import org.deeplearning4j.nn.conf.CacheMode;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.RNNFormat;
 import org.deeplearning4j.nn.gradient.Gradient;
 import org.deeplearning4j.nn.layers.LayerHelper;
 import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
 import org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper;
 import org.deeplearning4j.nn.params.LSTMParamInitializer;
+import org.deeplearning4j.util.TimeSeriesUtils;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -152,6 +154,14 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
         assertInputSet(false);
         Preconditions.checkState(input.rank() == 3,
                 "3D input expected to RNN layer expected, got " + input.rank());
+
+        boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(layerConf()) == RNNFormat.NWC;
+
+        INDArray origInput = input;
+        if(nwc){
+            input = permuteIfNWC(input);
+        }
+
         applyDropOutIfNecessary(training, workspaceMgr);

         //TODO LSTM cache mode is disabled for now - not passing all tests
@@ -166,7 +176,6 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
         final INDArray recurrentWeights = getParamWithNoise(LSTMParamInitializer.RECURRENT_WEIGHT_KEY, training, workspaceMgr); //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG]
         final INDArray inputWeights = getParamWithNoise(LSTMParamInitializer.INPUT_WEIGHT_KEY, training, workspaceMgr); //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg]
         final INDArray biases = getParamWithNoise(LSTMParamInitializer.BIAS_KEY, training, workspaceMgr); //by row: IFOG //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T
-        INDArray input = permuteIfNWC(this.input);
         FwdPassReturn fwd = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(),
                 input, recurrentWeights, inputWeights, biases, training, prevOutputActivations,
                 prevMemCellState, (training && cacheMode != CacheMode.NONE) || forBackprop, true,
@@ -178,6 +187,11 @@ public class LSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.L
         if (training && cacheMode != CacheMode.NONE) {
             cachedFwdPass = fwd;
         }
+
+        if(nwc){
+            input = origInput;
+        }
+
         return fwd;
     }

@@ -61,11 +61,8 @@ public class LastTimeStepLayer extends BaseWrapperLayer {
     @Override
     public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
         long[] newEpsShape = origOutputShape;
-        boolean nwc = (underlying instanceof BaseRecurrentLayer &&
-                ((BaseRecurrentLayer) underlying).getDataFormat() == RNNFormat.NWC)||
-                (underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof
-                        BaseRecurrentLayer && ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat()
-                        == RNNFormat.NWC);
-
+        boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC;
         INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f');
         if(lastTimeStepIdxs == null){
             //no mask case
@@ -58,7 +58,8 @@ public class RnnOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.l
                     "Input is not rank 3. RnnOutputLayer expects rank 3 input with shape [minibatch, layerInSize, sequenceLength]." +
                     " Got input with rank " + input.rank() + " and shape " + Arrays.toString(input.shape()) + " - " + layerId());
         }
-        int td = (layerConf().getRnnDataFormat()==RNNFormat.NCW)? 2: 1;
+        RNNFormat format = layerConf().getRnnDataFormat();
+        int td = (format == RNNFormat.NCW) ? 2 : 1;
         Preconditions.checkState(labels.rank() == 3, "Expected rank 3 labels array, got label array with shape %ndShape", labels);
         Preconditions.checkState(input.size(td) == labels.size(td), "Sequence lengths do not match for RnnOutputLayer input and labels:" +
                 "Arrays should be rank 3 with shape [minibatch, size, sequenceLength] - mismatch on dimension 2 (sequence length) - input=%ndShape vs. label=%ndShape", input, labels);
@@ -118,6 +118,7 @@ public class ConvolutionParamInitializer implements ParamInitializer {
             params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
             conf.addVariable(WEIGHT_KEY);
             conf.addVariable(BIAS_KEY);
+            conf.addVariable(BIAS_KEY);
         } else {
             INDArray weightView = paramsView;
             params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams));
@@ -34,6 +34,7 @@ import org.deeplearning4j.ui.api.UIServer;
 import org.deeplearning4j.ui.stats.StatsListener;
 import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
 import org.junit.Before;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@@ -53,7 +54,7 @@ import static org.junit.Assert.*;
 /**
  * @author Tamas Fenyvesi
  */
-@Slf4j
+@Slf4j @Ignore //https://github.com/eclipse/deeplearning4j/issues/8891
 public class TestVertxUIMultiSession extends BaseDL4JTest {

     @Before