diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java index 5c48b4c18..ad8601000 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java @@ -98,7 +98,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 90000L; + return 120_000L; } @Test @@ -156,7 +156,7 @@ public class TestGraphLocalExecution extends BaseDL4JTest { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), + .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), new MaxCandidatesCondition(3)) .build(); diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java index 91daa027f..caeffaaa7 100644 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java @@ -87,7 +87,7 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 45000L; + return 120_000L; } @Test @@ -154,8 +154,8 @@ public class TestGraphLocalExecutionGenetic extends BaseDL4JTest { .dataSource(ds, dsP) .modelSaver(new FileModelSaver(modelSave)) .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), - new MaxCandidatesCondition(10)) + .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), + new MaxCandidatesCondition(3)) .build(); IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator())); diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java index 0fd5793b8..407a71fea 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java @@ -18,6 +18,7 @@ package org.datavec.api.records.reader.impl; import lombok.Getter; import lombok.Setter; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.apache.commons.io.LineIterator; import org.datavec.api.conf.Configuration; @@ -43,6 +44,7 @@ import java.util.*; * * @author Adam Gibson */ +@Slf4j public class LineRecordReader extends BaseRecordReader { @@ -58,6 +60,13 @@ public class LineRecordReader extends BaseRecordReader { @Override public void initialize(InputSplit split) throws IOException, InterruptedException { super.initialize(split); + if(!(inputSplit instanceof StringSplit || inputSplit instanceof InputStreamInputSplit)){ + final ArrayList<URI> uris = new ArrayList<>(); + final Iterator<URI> uriIterator =
inputSplit.locationsIterator(); + while(uriIterator.hasNext()) uris.add(uriIterator.next()); + + this.locations = uris.toArray(new URI[0]); + } this.iter = getIterator(0); this.initialized = true; } @@ -66,7 +75,6 @@ public class LineRecordReader extends BaseRecordReader { public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { this.conf = conf; initialize(split); - this.initialized = true; } @Override @@ -89,7 +97,7 @@ public class LineRecordReader extends BaseRecordReader { iter = getIterator(splitIndex); onLocationOpen(locations[splitIndex]); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } if (iter.hasNext()) { @@ -120,7 +128,7 @@ public class LineRecordReader extends BaseRecordReader { iter = getIterator(splitIndex); onLocationOpen(locations[splitIndex]); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } return iter.hasNext(); @@ -205,11 +213,6 @@ public class LineRecordReader extends BaseRecordReader { } } } else { - final ArrayList<URI> uris = new ArrayList<>(); - final Iterator<URI> uriIterator = inputSplit.locationsIterator(); - while(uriIterator.hasNext()) uris.add(uriIterator.next()); - - this.locations = uris.toArray(new URI[uris.size()]); if (locations.length > 0) { InputStream inputStream = streamCreatorFn.apply(locations[location]); try {
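The LineRecordReader change above moves the materialization of split locations into initialize(), so the split is walked exactly once up front and record reading can index locations[splitIndex] directly. A minimal sketch of that pattern, assuming a hypothetical SplitLocations interface standing in for the one InputSplit method the patch relies on:

    import java.net.URI;
    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;

    // Hypothetical stand-in exposing only the InputSplit method used here.
    interface SplitLocations {
        Iterator<URI> locationsIterator();
    }

    class MaterializeLocations {
        // Drain the iterator once; toArray(new URI[0]) is the idiomatic sizing hint
        // the patch also switches to (instead of new URI[uris.size()]).
        static URI[] materialize(SplitLocations split) {
            List<URI> uris = new ArrayList<>();
            Iterator<URI> it = split.locationsIterator();
            while (it.hasNext()) {
                uris.add(it.next());
            }
            return uris.toArray(new URI[0]);
        }
    }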
diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java b/datavec/datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java index 2985462c8..5227a0cc3 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java @@ -16,6 +16,7 @@ package org.datavec.api.split; +import lombok.extern.slf4j.Slf4j; import org.datavec.api.util.files.UriFromPathIterator; import org.datavec.api.writable.WritableType; @@ -34,6 +35,7 @@ import java.util.regex.Pattern; * NumberedFileInputSplit utilizes String.format(), hence the requirement for "%d" to represent * the integer index. */ +@Slf4j public class NumberedFileInputSplit implements InputSplit { private final String baseString; private final int minIdx; @@ -93,7 +95,7 @@ public class NumberedFileInputSplit implements InputSplit { try { writeFile.createNewFile(); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/files/UriFromPathIterator.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/files/UriFromPathIterator.java index ffbd36fe5..4b910410d 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/util/files/UriFromPathIterator.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/files/UriFromPathIterator.java @@ -23,6 +23,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.Iterator; import java.util.NoSuchElementException; +import java.util.regex.Pattern; /** * A simple utility method to convert a {@code Iterator<String>} to an {@code Iterator<URI>}, where each @@ -32,6 +33,7 @@ import java.util.NoSuchElementException; */ @AllArgsConstructor public class UriFromPathIterator implements Iterator<URI> { + final Pattern schemaPattern = Pattern.compile("^.*?:/.*"); private final Iterator<String> paths; @@ -42,16 +44,17 @@ public class UriFromPathIterator implements Iterator<URI> { @Override public URI next() { + if (!hasNext()) { throw new NoSuchElementException("No next element"); } try { String s = paths.next(); - if(!s.matches(".*:/.*")){ + if(schemaPattern.matcher(s).matches()){ + return new URI(s); + } else { //No scheme - assume file for backward compatibility return new File(s).toURI(); - } else { - return new URI(s); } } catch (URISyntaxException e) { diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/Text.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/Text.java index d3c5b03d4..bdf401f1e 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/Text.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/Text.java @@ -162,7 +162,6 @@ public class Text extends BinaryComparable implements WritableComparable<BinaryComparable> ... asList(ndArrayWritable); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } } if (iter != null) {
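The UriFromPathIterator change just above replaces a per-element String.matches() call with a Pattern compiled once per iterator, and flips the branch so the scheme-present case returns new URI(s) directly. A minimal sketch of the precompilation idea (class and sample paths illustrative only):

    import java.util.regex.Pattern;

    class SchemeCheck {
        // Compiled once and reused; String.matches() recompiles the regex on every call.
        private static final Pattern SCHEME = Pattern.compile("^.*?:/.*");

        static boolean hasScheme(String path) {
            return SCHEME.matcher(path).matches();
        }

        public static void main(String[] args) {
            System.out.println(hasScheme("file:/tmp/data.csv")); // true  -> parse as URI
            System.out.println(hasScheme("/tmp/data.csv"));      // false -> fall back to File.toURI()
        }
    }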
diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java index b8fa0c43d..9d139cbf5 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java @@ -16,6 +16,7 @@ package org.datavec.image.loader; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.io.IOUtils; import org.bytedeco.javacpp.Loader; @@ -53,6 +54,7 @@ import static org.junit.Assert.fail; * * @author saudet */ +@Slf4j public class TestNativeImageLoader { static final long seed = 10; static final Random rng = new Random(seed); @@ -123,7 +125,7 @@ public class TestNativeImageLoader { try { array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); fail(); } assertEquals(5, array6.rank()); @@ -156,7 +158,7 @@ public class TestNativeImageLoader { try { array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } assertEquals(5, array8.rank()); assertEquals(pages2, array8.size(0)); @@ -172,7 +174,7 @@ public class TestNativeImageLoader { try { array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath()); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); fail(); } assertEquals(5, array9.rank()); diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java index 7ac2b2969..eb14fdabd 100644 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java +++ b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java @@ -66,7 +66,7 @@ public class UimaTokenizer implements Tokenizer { } catch (Exception e) { - e.printStackTrace(); + log.error("",e); throw new RuntimeException(e); } diff --git a/datavec/datavec-local/src/main/java/org/datavec/local/transforms/functions/SequenceRecordReaderFunction.java b/datavec/datavec-local/src/main/java/org/datavec/local/transforms/functions/SequenceRecordReaderFunction.java index 70391f66f..94026f2cf 100644 --- a/datavec/datavec-local/src/main/java/org/datavec/local/transforms/functions/SequenceRecordReaderFunction.java +++ b/datavec/datavec-local/src/main/java/org/datavec/local/transforms/functions/SequenceRecordReaderFunction.java @@ -17,6 +17,7 @@ package org.datavec.local.transforms.functions; +import lombok.extern.slf4j.Slf4j; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.writable.Writable; import org.nd4j.linalg.function.Function; @@ -32,6 +33,7 @@ import java.util.List; * sequence data into a {@code List<List<Writable>>} * @author Alex Black */ +@Slf4j public class SequenceRecordReaderFunction implements Function<Pair<String, InputStream>, List<List<Writable>>> { protected SequenceRecordReader sequenceRecordReader; @@ -46,7 +48,7 @@ public class SequenceRecordReaderFunction try (DataInputStream dis = (DataInputStream) value.getRight()) { return sequenceRecordReader.sequenceRecord(uri, dis); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } throw new IllegalStateException("Something went wrong");
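The recurring edit across these files swaps e.printStackTrace() for a Lombok-generated slf4j logger. A minimal sketch of the pattern, with a hypothetical class and a descriptive message (the diff itself logs an empty string, log.error("",e), which preserves the stack trace but loses context):

    import lombok.extern.slf4j.Slf4j;

    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Paths;

    @Slf4j
    class ReadWithLogging {
        byte[] readOrNull(String path) {
            try {
                return Files.readAllBytes(Paths.get(path));
            } catch (IOException e) {
                // Unlike e.printStackTrace(), this routes through the configured logging
                // backend and keeps the throwable (with its stack trace) attached to the event.
                log.error("Failed to read {}", path, e);
                return null;
            }
        }
    }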
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index 3d08d3141..d0ed42e36 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -20,6 +20,7 @@ package org.datavec.python; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; +import org.bytedeco.cpython.global.python; import org.bytedeco.numpy.global.numpy; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; @@ -343,6 +344,19 @@ public class PythonExecutioner { if (path == null) { log.info("Setting python default path"); File[] packages = numpy.cachePackages(); + + //// TODO: fix in javacpp + File sitePackagesWindows = new File(python.cachePackage(), "site-packages"); + File[] packages2 = new File[packages.length + 1]; + for (int i = 0;i < packages.length; i++){ + //System.out.println(packages[i].getAbsolutePath()); + packages2[i] = packages[i]; + } + packages2[packages.length] = sitePackagesWindows; + //System.out.println(sitePackagesWindows.getAbsolutePath()); + packages = packages2; + ////////// + Py_SetPath(packages); } else { log.info("Setting python path " + path); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java new file mode 100644 index 000000000..a8ee56510 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonProcess.java @@ -0,0 +1,132 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + + +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.IOUtils; +import org.bytedeco.javacpp.Loader; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; + +@Slf4j +public class PythonProcess { + private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); + public static String runAndReturn(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + log.info("Executing command: " + Arrays.toString(allArgs)); + ProcessBuilder pb = new ProcessBuilder(allArgs); + Process process = pb.start(); + String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8); + process.waitFor(); + return out; + + } + + public static void run(String...
arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + log.info("Executing command: " + Arrays.toString(allArgs)); + ProcessBuilder pb = new ProcessBuilder(allArgs); + pb.inheritIO().start().waitFor(); + } + public static void pipInstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "install", packageName); + }catch(Exception e){ + throw new PythonException("Error installing package " + packageName, e); + } + + } + + public static void pipInstall(String packageName, String version) throws PythonException{ + pipInstall(packageName + "==" + version); + } + + public static void pipUninstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "uninstall", "-y", packageName); + }catch(Exception e){ + throw new PythonException("Error uninstalling package " + packageName, e); + } + + } + public static void pipInstallFromGit(String gitRepoUrl) throws PythonException{ + if (!gitRepoUrl.contains("://")){ + gitRepoUrl = "git://" + gitRepoUrl; + } + try{ + run("-m", "pip", "install", "git+" + gitRepoUrl); + }catch(Exception e){ + throw new PythonException("Error installing package from " + gitRepoUrl, e); + } + + } + + public static String getPackageVersion(String packageName) throws PythonException{ + String out; + try{ + out = runAndReturn("-m", "pip", "show", packageName); + } catch (Exception e){ + throw new PythonException("Error finding version for package " + packageName, e); + } + + if (!out.contains("Version: ")){ + throw new PythonException("Can't find package " + packageName); + } + String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0]; + return pkgVersion; + } + + public static boolean isPackageInstalled(String packageName)throws PythonException{ + try{ + String out = runAndReturn("-m", "pip", "show", packageName); + return !out.isEmpty(); + }catch (Exception e){ + throw new PythonException("Error checking if package is installed: " +packageName, e); + } + + } + + public static void pipInstallFromRequirementsTxt(String path) throws PythonException{ + try{ + run("-m", "pip", "install","-r", path); + }catch (Exception e){ + throw new PythonException("Error installing packages from " + path, e); + } + } + + public static void pipInstallFromSetupScript(String path, boolean inplace) throws PythonException{ + + try{ + run(path, inplace?"develop":"install"); + }catch (Exception e){ + throw new PythonException("Error installing package from " + path, e); + } + + } + +} \ No newline at end of file
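A short usage sketch for the PythonProcess helpers defined above, assuming the datavec-python artifact (and its bundled JavaCPP CPython) is on the classpath; the package name and version are illustrative:

    import org.datavec.python.PythonException;
    import org.datavec.python.PythonProcess;

    class PipExample {
        public static void main(String[] args) throws PythonException {
            // Install a pinned version only if absent, then report what pip resolved.
            if (!PythonProcess.isPackageInstalled("numpy")) {
                PythonProcess.pipInstall("numpy", "1.19.5");
            }
            System.out.println("numpy " + PythonProcess.getPackageVersion("numpy"));
        }
    }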
diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/keras/Model.java b/datavec/datavec-python/src/main/java/org/datavec/python/keras/Model.java new file mode 100644 index 000000000..d8a9b0651 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/keras/Model.java @@ -0,0 +1,144 @@ +package org.datavec.python.keras; + +import org.datavec.python.Python; +import org.datavec.python.PythonException; +import org.datavec.python.PythonObject; +import org.datavec.python.PythonProcess; +import org.nd4j.linalg.api.ndarray.INDArray; + +public class Model { + + private PythonObject pyModel; + + + private static PythonObject installAndImportTF() throws PythonException{ + if (!PythonProcess.isPackageInstalled("tensorflow")){ + PythonProcess.pipInstall("tensorflow"); + } + return Python.importModule("tensorflow"); + } + private static PythonObject getKerasModule() throws PythonException{ + PythonObject tf = installAndImportTF(); + PythonObject keras = tf.attr("keras"); + tf.del(); + return keras; + } + + private static PythonObject loadModel(String s) throws PythonException{ + PythonObject models = getKerasModule().attr("models"); + PythonObject loadModelF = models.attr("load_model"); + PythonObject model = loadModelF.call(s); + models.del(); + loadModelF.del(); + return model; + } + + public Model(String path) throws PythonException{ + pyModel = loadModel(path); + } + + public INDArray[] predict(INDArray... inputs) throws PythonException{ + PythonObject predictF = pyModel.attr("predict"); + PythonObject inputList = new PythonObject(inputs); + PythonObject pyOut = predictF.call(inputList); + INDArray[] out; + if (Python.isinstance(pyOut, Python.listType())){ + out = new INDArray[Python.len(pyOut).toInt()]; + for(int i = 0; i < out.length; i++){ + out[i] = pyOut.get(i).toNumpy().getNd4jArray(); + } + } + else{ + out = new INDArray[]{ + pyOut.toNumpy().getNd4jArray()}; + } + + predictF.del(); + inputList.del(); + pyOut.del(); + return out; + } + + public int numInputs(){ + PythonObject inputs = pyModel.attr("inputs"); + PythonObject pyNumInputs = Python.len(inputs); + int ret = pyNumInputs.toInt(); + inputs.del(); + pyNumInputs.del(); + return ret; + } + public int numOutputs(){ + PythonObject outputs = pyModel.attr("outputs"); + PythonObject pyNumOutputs = Python.len(outputs); + int ret = pyNumOutputs.toInt(); + outputs.del(); + pyNumOutputs.del(); + return ret; + } + + public long[][] inputShapes(){ + long[][] ret = new long[numInputs()][]; + for (int i = 0; i < ret.length; i++){ + ret[i] = inputShapeAt(i); + } + return ret; + } + + public long[][] outputShapes(){ + long[][] ret = new long[numOutputs()][]; + for (int i = 0; i < ret.length; i++){ + ret[i] = outputShapeAt(i); + } + return ret; + } + + public long[] inputShapeAt(int input){ + PythonObject inputs = pyModel.attr("inputs"); + PythonObject tensor = inputs.get(input); + PythonObject tensorShape = tensor.attr("shape"); + PythonObject shapeList = Python.list(tensorShape); + PythonObject pyNdim = Python.len(shapeList); + int ndim = pyNdim.toInt(); + long[] shape = new long[ndim]; + for(int i = 0; i < shape.length; i++){ + PythonObject pyDim = shapeList.get(i); + if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){ + shape[i] = -1; + } + else{ + shape[i] = pyDim.toLong(); + } + } + pyNdim.del(); + shapeList.del(); + tensorShape.del(); + tensor.del(); + inputs.del(); + return shape; + } + + public long[] outputShapeAt(int output){ + PythonObject outputs = pyModel.attr("outputs"); + PythonObject tensor = outputs.get(output); + PythonObject tensorShape = tensor.attr("shape"); + PythonObject shapeList = Python.list(tensorShape); + PythonObject pyNdim = Python.len(shapeList); + int ndim = pyNdim.toInt(); + long[] shape = new long[ndim]; + for(int i = 0; i < shape.length; i++){ + PythonObject pyDim = shapeList.get(i); + if (pyDim == null || !Python.isinstance(pyDim, Python.intType())){ + shape[i] = -1; + } + else{ + shape[i] = pyDim.toLong(); + } + } + pyNdim.del(); + shapeList.del(); + tensorShape.del(); + tensor.del(); + outputs.del(); + return shape; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java index
e84488f9d..a481f2cea 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/transform/client/DataVecTransformClient.java @@ -74,7 +74,6 @@ public class DataVecTransformClient implements DataVecTransformService { } catch (UnirestException e) { log.error("Error in setCSVTransformProcess()", e); - e.printStackTrace(); } } @@ -94,7 +93,6 @@ public class DataVecTransformClient implements DataVecTransformService { return TransformProcess.fromJson(s); } catch (UnirestException e) { log.error("Error in getCSVTransformProcess()",e); - e.printStackTrace(); } return null; @@ -119,7 +117,6 @@ public class DataVecTransformClient implements DataVecTransformService { return singleCsvRecord; } catch (UnirestException e) { log.error("Error in transformIncremental(SingleCSVRecord)",e); - e.printStackTrace(); } return null; } @@ -140,8 +137,7 @@ public class DataVecTransformClient implements DataVecTransformService { .getBody(); return batchCSVRecord1; } catch (UnirestException e) { - log.error("Error in transform(BatchCSVRecord)", e); - e.printStackTrace(); + log.error("",e); } return null; @@ -162,7 +158,6 @@ public class DataVecTransformClient implements DataVecTransformService { return batchCSVRecord1; } catch (UnirestException e) { log.error("Error in transform(BatchCSVRecord)", e); - e.printStackTrace(); } return null; @@ -181,7 +176,6 @@ public class DataVecTransformClient implements DataVecTransformService { return batchArray1; } catch (UnirestException e) { log.error("Error in transformArray(BatchCSVRecord)",e); - e.printStackTrace(); } return null; @@ -200,7 +194,6 @@ public class DataVecTransformClient implements DataVecTransformService { return array; } catch (UnirestException e) { log.error("Error in transformArrayIncremental(SingleCSVRecord)",e); - e.printStackTrace(); } return null; @@ -231,7 +224,6 @@ public class DataVecTransformClient implements DataVecTransformService { return array; } catch (UnirestException e) { log.error("Error in transformSequenceArrayIncremental",e); - e.printStackTrace(); } return null; @@ -252,7 +244,6 @@ public class DataVecTransformClient implements DataVecTransformService { return batchArray1; } catch (UnirestException e) { log.error("Error in transformSequenceArray",e); - e.printStackTrace(); } return null; @@ -274,7 +265,6 @@ public class DataVecTransformClient implements DataVecTransformService { return batchCSVRecord1; } catch (UnirestException e) { log.error("Error in transformSequence"); - e.printStackTrace(); } return null; @@ -295,7 +285,6 @@ public class DataVecTransformClient implements DataVecTransformService { return singleCsvRecord; } catch (UnirestException e) { log.error("Error in transformSequenceIncremental"); - e.printStackTrace(); } return null; } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java index f76e9885f..bc7ad126c 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java @@ 
-18,6 +18,7 @@ package org.datavec.spark.transform; import lombok.AllArgsConstructor; import lombok.Getter; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -53,7 +54,7 @@ import static org.datavec.local.transforms.LocalTransformExecutor.executeToSeque * @author Adan Gibson */ @AllArgsConstructor - +@Slf4j public class CSVSparkTransform { @Getter private TransformProcess transformProcess; @@ -252,7 +253,7 @@ public class CSVSparkTransform { try { return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } return null; diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java index f8675f139..d64416d32 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/transform/ImageSparkTransformServer.java @@ -88,7 +88,7 @@ public class ImageSparkTransformServer extends SparkTransformServer { log.info("Transform process initialized"); return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); return internalServerError(); } }); @@ -100,7 +100,7 @@ public class ImageSparkTransformServer extends SparkTransformServer { log.info("Transform process initialized"); return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); return internalServerError(); } }); @@ -112,7 +112,7 @@ public class ImageSparkTransformServer extends SparkTransformServer { return badRequest(); return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); return internalServerError(); } }); @@ -130,7 +130,7 @@ public class ImageSparkTransformServer extends SparkTransformServer { return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); return internalServerError(); } }); @@ -142,7 +142,7 @@ public class ImageSparkTransformServer extends SparkTransformServer { return badRequest(); return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); return internalServerError(); } }); @@ -169,7 +169,7 @@ public class ImageSparkTransformServer extends SparkTransformServer { return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); return internalServerError(); } }); diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/BaseSparkTest.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/BaseSparkTest.java index 92a11f779..c474f8962 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/BaseSparkTest.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/BaseSparkTest.java @@ -16,6 +16,7 @@ package org.datavec.spark; +import 
lombok.extern.slf4j.Slf4j; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.junit.After; @@ -23,6 +24,7 @@ import org.junit.Before; import java.io.Serializable; +@Slf4j public abstract class BaseSparkTest implements Serializable { protected static JavaSparkContext sc; @@ -40,7 +42,7 @@ public abstract class BaseSparkTest implements Serializable { try { Thread.sleep(100L); } catch (InterruptedException e) { - e.printStackTrace(); + log.error("",e); } } else { break; diff --git a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index 6ce978ec5..22bdd19ae 100644 --- a/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -68,7 +68,7 @@ public abstract class BaseDL4JTest { * Override this method to set the default timeout for methods in the test class */ public long getTimeoutMilliseconds(){ - return 30000; + return 60_000; } /** diff --git a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/api/storage/impl/RemoteUIStatsStorageRouter.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/api/storage/impl/RemoteUIStatsStorageRouter.java index 439e87eb7..1e4c62c3c 100644 --- a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/api/storage/impl/RemoteUIStatsStorageRouter.java +++ b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/api/storage/impl/RemoteUIStatsStorageRouter.java @@ -43,7 +43,8 @@ import java.util.concurrent.atomic.AtomicLong; */ @Slf4j public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializable, Closeable { - + private static final String ROUTE_IS_DOWN = "Info posted to RemoteUIStatsStorageRouter but router is shut down."; + private static final String MAX_WARNINGS_REACHED = "RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced."; /** * Default path for posting data to the UI - i.e., http://localhost:9000/remoteReceive or similar */ @@ -163,10 +164,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa if (shutdown.get()) { long count = shutdownWarnCount.getAndIncrement(); if (count <= MAX_SHUTDOWN_WARN_COUNT) { - log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down."); + log.warn(ROUTE_IS_DOWN); } if (count == MAX_SHUTDOWN_WARN_COUNT) { - log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced."); + log.warn(MAX_WARNINGS_REACHED); } } else { for (StorageMetaData m : storageMetaData) { @@ -186,10 +187,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa if (shutdown.get()) { long count = shutdownWarnCount.getAndIncrement(); if (count <= MAX_SHUTDOWN_WARN_COUNT) { - log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down."); + log.warn(ROUTE_IS_DOWN); } if (count == MAX_SHUTDOWN_WARN_COUNT) { - log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. 
No further warnings will be produced."); + log.warn(MAX_WARNINGS_REACHED); } } else { for (Persistable p : staticInfo) { @@ -209,10 +210,10 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa if (shutdown.get()) { long count = shutdownWarnCount.getAndIncrement(); if (count <= MAX_SHUTDOWN_WARN_COUNT) { - log.warn("Info posted to RemoteUIStatsStorageRouter but router is shut down."); + log.warn(ROUTE_IS_DOWN); } if (count == MAX_SHUTDOWN_WARN_COUNT) { - log.warn("RemoteUIStatsStorageRouter: Reached max shutdown warnings. No further warnings will be produced."); + log.warn(MAX_WARNINGS_REACHED); } } else { for (Persistable p : updates) { diff --git a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/parallelism/AsyncIterator.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/parallelism/AsyncIterator.java index 154934b7e..b2d3e1169 100644 --- a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/parallelism/AsyncIterator.java +++ b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/parallelism/AsyncIterator.java @@ -67,10 +67,7 @@ public class AsyncIterator<T extends Object> implements Iterator<T> { nextElement = buffer.take(); // same on this run - if (nextElement == terminator) - return false; - - return true; + return (nextElement != terminator); } catch (Exception e) { log.error("Premature end of loop!"); return false; diff --git a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java index 7bdb3491c..dc9981535 100644 --- a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java +++ b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java @@ -43,7 +43,7 @@ public class SystemInfoPrintListener implements TrainingListener { private boolean printOnBackwardPass; private boolean printOnGradientCalculation; - + private static final String SYSTEM_INFO = "System info on epoch end: "; @Override public void iterationDone(Model model, int iteration, int epoch) { @@ -65,7 +65,7 @@ public class SystemInfoPrintListener implements TrainingListener { return; SystemInfo systemInfo = new SystemInfo(); - log.info("System info on epoch end: "); + log.info(SYSTEM_INFO); log.info(systemInfo.toPrettyJSON()); } @@ -75,7 +75,7 @@ public class SystemInfoPrintListener implements TrainingListener { return; SystemInfo systemInfo = new SystemInfo(); - log.info("System info on epoch end: "); + log.info(SYSTEM_INFO); log.info(systemInfo.toPrettyJSON()); } @@ -85,7 +85,7 @@ public class SystemInfoPrintListener implements TrainingListener { return; SystemInfo systemInfo = new SystemInfo(); - log.info("System info on epoch end: "); + log.info(SYSTEM_INFO); log.info(systemInfo.toPrettyJSON()); } @@ -95,7 +95,7 @@ public class SystemInfoPrintListener implements TrainingListener { return; SystemInfo systemInfo = new SystemInfo(); - log.info("System info on epoch end: "); + log.info(SYSTEM_INFO); log.info(systemInfo.toPrettyJSON()); } @@ -104,7 +104,7 @@ public class SystemInfoPrintListener implements TrainingListener { if(!printOnBackwardPass) return; SystemInfo systemInfo = new SystemInfo(); - log.info("System info on epoch end: "); + log.info(SYSTEM_INFO); log.info(systemInfo.toPrettyJSON()); } }
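The AsyncIterator hunk above collapses an if/else return into a single comparison against the terminator sentinel. A minimal sketch of that sentinel idiom for a blocking producer/consumer handoff (class and capacity hypothetical):

    import java.util.concurrent.ArrayBlockingQueue;
    import java.util.concurrent.BlockingQueue;

    class SentinelQueue {
        // A unique object marks end-of-stream, so consumers can block on take()
        // without a separate "done" flag that could race with the producer.
        private static final Object TERMINATOR = new Object();
        private final BlockingQueue<Object> buffer = new ArrayBlockingQueue<>(16);

        void put(Object value) throws InterruptedException { buffer.put(value); }

        void close() throws InterruptedException { buffer.put(TERMINATOR); }

        // Returns null at end of stream; identity comparison is deliberate, since
        // TERMINATOR is private and can never equal a produced value.
        Object take() throws InterruptedException {
            Object next = buffer.take();
            return (next != TERMINATOR) ? next : null;
        }
    }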
diff --git a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java index 4813bcf73..323f6e96e 100644 --- a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java +++ b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java @@ -16,6 +16,7 @@ package org.deeplearning4j.perf.listener; +import lombok.extern.slf4j.Slf4j; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import oshi.json.SystemInfo; @@ -36,6 +37,7 @@ import java.util.concurrent.TimeUnit; * * @author Adam Gibson */ +@Slf4j public class SystemPolling { private ScheduledExecutorService scheduledExecutorService; @@ -66,7 +68,7 @@ public class SystemPolling { try { objectMapper.writeValue(hardwareFile,hardwareMetric); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } } },0,pollEveryMillis, TimeUnit.MILLISECONDS); diff --git a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java index 96e29d1ac..ed2d146cc 100644 --- a/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java +++ b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ModelGuesser.java @@ -27,7 +27,6 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; import java.io.*; -import java.nio.file.Files; import java.util.UUID; /** diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java index d54693f73..dac854d9d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -20,6 +20,7 @@ import org.apache.commons.compress.utils.IOUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -153,11 +154,22 @@ public class TestUtils { return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed)); } - public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){ - INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f'); + public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng) { + return randomOneHotTimeSeries(RNNFormat.NCW, minibatch, outSize, tsLength, rng); + } + + public static INDArray randomOneHotTimeSeries(RNNFormat format, int minibatch, int outSize, int tsLength, Random rng){ + boolean ncw = format == RNNFormat.NCW; + long[] shape = ncw ? new long[]{minibatch, outSize, tsLength} : new long[]{minibatch, tsLength, outSize}; + char order = ncw ?
'f' : 'c'; + INDArray out = Nd4j.create(DataType.FLOAT, shape, order); for( int i=0; i<minibatch; i++){ ... (dArr,temp); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java index 145a12773..19b56679c 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java @@ -78,10 +78,10 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); earlyEndIter.next(10); - assertEquals(earlyEndIter.hasNext(), false); + assertEquals(false, earlyEndIter.hasNext()); earlyEndIter.reset(); - assertEquals(earlyEndIter.hasNext(), true); + assertEquals(true, earlyEndIter.hasNext()); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java index f37642c24..10a4285a1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java @@ -98,7 +98,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest { while (multiIter.hasNext()) { DataSet path = multiIter.next(10); assertNotNull(path); - assertEquals(path.numExamples(), 10, 0.0); + assertEquals(10, path.numExamples(), 0.0); } assertEquals(epochs, multiIter.epochs); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/SamplingTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/SamplingTest.java index 5b159a015..fa8f20662 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/SamplingTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/SamplingTest.java @@ -33,7 +33,7 @@ public class SamplingTest extends BaseDL4JTest { DataSetIterator iter = new MnistDataSetIterator(10, 10); //batch size and total DataSetIterator sampling = new SamplingDataSetIterator(iter.next(), 10, 10); - assertEquals(sampling.next().numExamples(), 10); + assertEquals(10, sampling.next().numExamples()); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java index f036d780a..0ea7ff282 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java @@ -16,6 +16,7 @@ package org.deeplearning4j.exceptions; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.conf.ConvolutionMode; @@ -34,6 +35,7 @@ import static org.junit.Assert.fail; /** * A set of tests to
ensure that useful exceptions are thrown on invalid network configurations */ +@Slf4j public class TestInvalidConfigurations extends BaseDL4JTest { public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) { @@ -78,7 +80,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testDenseNin0(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -96,7 +98,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testDenseNout0(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -109,7 +111,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testOutputLayerNin0(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -122,7 +124,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testRnnOutputLayerNin0(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -135,7 +137,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testLSTMNIn0(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -153,7 +155,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testLSTMNOut0(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -166,7 +168,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testConvolutionalNin0(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -185,7 +187,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testConvolutionalNOut0(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -216,7 +218,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testCnnInvalidConfigPaddingStridesHeight(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -245,7 +247,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testCnnInvalidConfigOrInput_SmallerDataThanKernel(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -277,7 +279,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testCnnInvalidConfigOrInput_BadStrides(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -318,7 +320,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testCnnInvalidConfigPaddingStridesWidth(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } } @@ -347,7 +349,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testCnnInvalidConfigPaddingStridesWidthSubsampling(): " + e.getMessage()); } catch (Exception e) { - 
e.printStackTrace(); + log.error("",e); fail(); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java index 096c7ac69..49c65e2c8 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java @@ -16,6 +16,7 @@ package org.deeplearning4j.exceptions; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; @@ -36,6 +37,7 @@ import static org.junit.Assert.*; /** * A set of tests to ensure that useful exceptions are thrown on invalid input */ +@Slf4j public class TestInvalidInput extends BaseDL4JTest { @Test @@ -53,7 +55,7 @@ public class TestInvalidInput extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testInputNinMismatchDense(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } } @@ -73,7 +75,7 @@ public class TestInvalidInput extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testInputNinMismatchOutputLayer(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } } @@ -94,7 +96,7 @@ public class TestInvalidInput extends BaseDL4JTest { //From loss function System.out.println("testLabelsNOutMismatchOutputLayer(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } } @@ -115,7 +117,7 @@ public class TestInvalidInput extends BaseDL4JTest { //From loss function System.out.println("testLabelsNOutMismatchRnnOutputLayer(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } } @@ -142,7 +144,7 @@ public class TestInvalidInput extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testInputNinMismatchConvolutional(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } } @@ -169,7 +171,7 @@ public class TestInvalidInput extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testInputNinRank2Convolutional(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } } @@ -195,7 +197,7 @@ public class TestInvalidInput extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testInputNinRank2Subsampling(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } } @@ -217,7 +219,7 @@ public class TestInvalidInput extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testInputNinMismatchLSTM(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } } @@ -238,7 +240,7 @@ public class TestInvalidInput extends BaseDL4JTest { } catch (DL4JException e) { System.out.println("testInputNinMismatchBidirectionalLSTM(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } @@ -260,7 +262,7 @@ public class TestInvalidInput extends BaseDL4JTest { } catch (DL4JException e) { 
System.out.println("testInputNinMismatchEmbeddingLayer(): " + e.getMessage()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Expected DL4JException"); } } @@ -305,7 +307,7 @@ public class TestInvalidInput extends BaseDL4JTest { net.rnnTimeStep(Nd4j.create(5, 5, 10)); fail("Expected Exception - " + layerType); } catch (Exception e) { -// e.printStackTrace(); + log.error("",e); String msg = e.getMessage(); assertTrue(msg, msg != null && msg.contains("rnn") && msg.contains("batch")); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index eb4a51309..c303cc594 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -34,6 +35,8 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Ignore; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -51,6 +54,7 @@ import static org.junit.Assert.*; /** * Created by nyghtowl on 9/1/15. */ +@RunWith(Parameterized.class) public class CNNGradientCheckTest extends BaseDL4JTest { private static final boolean PRINT_RESULTS = true; private static final boolean RETURN_ON_FIRST_FAILURE = false; @@ -62,6 +66,17 @@ public class CNNGradientCheckTest extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } + private CNN2DFormat format; + + public CNNGradientCheckTest(CNN2DFormat format){ + this.format = format; + } + + @Parameterized.Parameters(name = "{0}") + public static Object[] params(){ + return CNN2DFormat.values(); + } + @Override public long getTimeoutMilliseconds() { return 90000L; @@ -69,6 +84,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { @Test public void testGradientCNNMLN() { + if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format... + return; + //Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') @@ -144,6 +162,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { @Test public void testGradientCNNL1L2MLN() { + if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format... 
+ return; + //Parameterized test, testing combinations of: // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') @@ -311,10 +332,12 @@ public class CNNGradientCheckTest extends BaseDL4JTest { new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; + boolean nchw = format == CNN2DFormat.NCHW; for (String afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); + long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = Nd4j.zeros(4 * minibatchSize, nOut); for (int i = 0; i < 4 * minibatchSize; i++) { labels.putScalar(new int[]{i, i % nOut}, 1.0); @@ -330,13 +353,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX) .nOut(nOut).build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)) + .setInputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; if (PRINT_RESULTS) { @@ -377,8 +400,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest { int[] padding = {0, 0}; int size = 2; + boolean nchw = format == CNN2DFormat.NCHW; + for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); + long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); MultiLayerConfiguration conf = @@ -393,8 +419,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(8 * 8 * 3) .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) + .setInputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -438,10 +463,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest { new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; + boolean nchw = format == CNN2DFormat.NCHW; + for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); + long[] inShape = nchw ? 
new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { labels.putScalar(new int[]{i, i % nOut}, 1.0); @@ -461,14 +489,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(3 * 3 * 3) .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) + .setInputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - String msg = "PoolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + String msg = format + " - poolingType=" + poolingType + ", minibatch=" + minibatchSize + ", activationFn=" + afn; if (PRINT_RESULTS) { @@ -508,10 +535,13 @@ public class CNNGradientCheckTest extends BaseDL4JTest { new SubsamplingLayer.PoolingType[]{SubsamplingLayer.PoolingType.MAX, SubsamplingLayer.PoolingType.AVG, SubsamplingLayer.PoolingType.PNORM}; + boolean nchw = format == CNN2DFormat.NCHW; + for (Activation afn : activations) { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { - INDArray input = Nd4j.rand(minibatchSize, width * height * inputDepth); + long[] inShape = nchw ? new long[]{minibatchSize, inputDepth, height, width} : new long[]{minibatchSize, height, width, inputDepth}; + INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = Nd4j.zeros(minibatchSize, nOut); for (int i = 0; i < minibatchSize; i++) { labels.putScalar(new int[]{i, i % nOut}, 1.0); @@ -533,8 +563,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(2 * 2 * 2) .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, - inputDepth)) + .setInputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -558,8 +587,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest { @Test public void testCnnLocallyConnected2D() { int nOut = 3; - - int[] minibatchSizes = {2}; int width = 5; int height = 5; @@ -569,11 +596,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest { Activation[] activations = {Activation.SIGMOID, Activation.TANH, Activation.SOFTPLUS}; int[] minibatch = {2, 1, 3}; + boolean nchw = format == CNN2DFormat.NCHW; + for( int i=0; i= 0.0); } else if (lc instanceof UnitNormConstraint) { - assertEquals(RW0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); - assertEquals(RW0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); + assertEquals(1.0, RW0.norm2(1).minNumber().doubleValue(), 1e-6); + assertEquals(1.0, RW0.norm2(1).maxNumber().doubleValue(), 1e-6); } TestUtils.testModelSerialization(net); @@ -149,8 +149,8 @@ public class TestConstraints extends BaseDL4JTest { } else if (lc instanceof NonNegativeConstraint) { assertTrue(b0.minNumber().doubleValue() >= 0.0); } else if (lc instanceof UnitNormConstraint) { - assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); - assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); + assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6); + assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6); } 
TestUtils.testModelSerialization(net); @@ -201,8 +201,8 @@ public class TestConstraints extends BaseDL4JTest { } else if (lc instanceof NonNegativeConstraint) { assertTrue(w0.minNumber().doubleValue() >= 0.0); } else if (lc instanceof UnitNormConstraint) { - assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); - assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); + assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6); + assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6); } TestUtils.testModelSerialization(net); @@ -259,10 +259,10 @@ public class TestConstraints extends BaseDL4JTest { assertTrue(w0.minNumber().doubleValue() >= 0.0); assertTrue(b0.minNumber().doubleValue() >= 0.0); } else if (lc instanceof UnitNormConstraint) { - assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); - assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); - assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); - assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); + assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6); + assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6); + assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6); + assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6); } TestUtils.testModelSerialization(net); @@ -320,10 +320,10 @@ public class TestConstraints extends BaseDL4JTest { assertTrue(w0.minNumber().doubleValue() >= 0.0); assertTrue(b0.minNumber().doubleValue() >= 0.0); } else if (lc instanceof UnitNormConstraint) { - assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); - assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); - assertEquals(b0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6); - assertEquals(b0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6); + assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6); + assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6); + assertEquals(1.0, b0.norm2(1).minNumber().doubleValue(), 1e-6); + assertEquals(1.0, b0.norm2(1).maxNumber().doubleValue(), 1e-6); } TestUtils.testModelSerialization(net); @@ -378,10 +378,10 @@ public class TestConstraints extends BaseDL4JTest { } else if(lc instanceof NonNegativeConstraint ){ assertTrue(w0.minNumber().doubleValue() >= 0.0 ); } else if(lc instanceof UnitNormConstraint ){ - assertEquals(w0.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 ); - assertEquals(w0.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 ); - assertEquals(w1.norm2(1).minNumber().doubleValue(), 1.0, 1e-6 ); - assertEquals(w1.norm2(1).maxNumber().doubleValue(), 1.0, 1e-6 ); + assertEquals(1.0, w0.norm2(1).minNumber().doubleValue(), 1e-6 ); + assertEquals(1.0, w0.norm2(1).maxNumber().doubleValue(), 1e-6 ); + assertEquals(1.0, w1.norm2(1).minNumber().doubleValue(), 1e-6 ); + assertEquals(1.0, w1.norm2(1).maxNumber().doubleValue(), 1e-6 ); } TestUtils.testModelSerialization(net); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java index 1288b1078..12f5a8e0d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java @@ -156,7 +156,7 @@ public class LayerBuilderTest extends BaseDL4JTest { checkSerialization(glstm); - 
assertEquals(glstm.getForgetGateBiasInit(), 1.5, 0.0); + assertEquals(1.5, glstm.getForgetGateBiasInit(), 0.0); assertEquals(glstm.nIn, numIn); assertEquals(glstm.nOut, numOut); assertTrue(glstm.getActivationFn() instanceof ActivationTanH); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index 6831af10b..beec5cf20 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -17,7 +17,9 @@ package org.deeplearning4j.nn.dtypes; import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; +import org.deeplearning4j.nn.conf.preprocessor.*; import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayer; +import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor; import org.nd4j.shade.guava.collect.ImmutableSet; import org.nd4j.shade.guava.reflect.ClassPath; import lombok.extern.slf4j.Slf4j; @@ -51,16 +53,11 @@ import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer; -import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor; -import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor; -import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnn3DPreProcessor; -import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.util.IdentityLayer; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; -import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; @@ -97,7 +94,8 @@ public class DTypeTests extends BaseDL4JTest { Pooling2D.class, //Alias for SubsamplingLayer Convolution2D.class, //Alias for ConvolutionLayer Pooling1D.class, //Alias for Subsampling1D - Convolution1D.class //Alias for Convolution1DLayer + Convolution1D.class, //Alias for Convolution1DLayer + TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated )); @Override @@ -819,7 +817,7 @@ public class DTypeTests extends BaseDL4JTest { .layer(new DenseLayer.Builder().nOut(5).build()) .layer(new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()) .layer(new Bidirectional(new LSTM.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())) - .layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build(), 2)) + .layer(new TimeDistributed(new DenseLayer.Builder().nIn(10).nOut(5).activation(Activation.TANH).build())) .layer(new SimpleRnn.Builder().nIn(5).nOut(5).build()) .layer(new MaskZeroLayer.Builder().underlying(new SimpleRnn.Builder().nIn(5).nOut(5).build()).maskValue(0.0).build()) .layer(secondLast) @@ -1078,7 +1076,7 @@ public class DTypeTests extends BaseDL4JTest { .addLayer("l", 
new DenseLayer.Builder().nOut(16).build(), "in") .addVertex("preproc", new PreprocessorVertex(new FeedForwardToCnn3DPreProcessor(2, 2, 2, 2, true)), "l") .addVertex("preproc2", new PreprocessorVertex(new PermutePreprocessor(0, 2, 3, 4, 1)), "preproc") - .addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16})), "preproc2") + .addVertex("preproc3", new PreprocessorVertex(new ReshapePreprocessor(new long[]{2, 2, 2, 2}, new long[]{16}, false)), "preproc2") .addLayer("out", new OutputLayer.Builder().nIn(16).nOut(10).build(), "preproc3") .setInputTypes(InputType.feedForward(5)) .setOutputs("out"); @@ -1150,7 +1148,7 @@ public class DTypeTests extends BaseDL4JTest { case 7: b.addInputs("in") .addLayer("1", new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(5).convolutionMode(ConvolutionMode.Same).build(), "in") - .addVertex("2", new PreprocessorVertex(new TensorFlowCnnToFeedForwardPreProcessor(28, 28, 5)), "1") + .addVertex("2", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 5)), "1") .addLayer("out", new OutputLayer.Builder().nOut(10).build(), "2") .setOutputs("out") .setInputTypes(InputType.convolutional(28, 28, 1)); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java index ff6ff2da2..a6778626f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.graph; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; @@ -54,6 +55,7 @@ import java.util.Map; import static org.junit.Assert.*; +@Slf4j public class ComputationGraphTestRNN extends BaseDL4JTest { @Test @@ -618,7 +620,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { .build(); fail("Exception expected"); } catch (IllegalStateException e){ -// e.printStackTrace(); + log.error("",e); assertTrue(e.getMessage().contains("TBPTT") && e.getMessage().contains("validateTbpttConfig")); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index 28bc42983..7ba76894d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -1394,7 +1394,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { } catch (Exception e) { - //e.printStackTrace(); + log.error("",e); if(allowDisconnected){ fail("No exception expected"); } else { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java index f249acafc..05025fde5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java @@ 
-416,8 +416,8 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { INDArray outRow2 = out2.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j)); for (int k = 0; k < nOut; k++) { - assertEquals(outRow.getDouble(k), 0.0, 0.0); - assertEquals(outRow2.getDouble(k), 0.0, 0.0); + assertEquals(0.0, outRow.getDouble(k), 0.0); + assertEquals(0.0, outRow2.getDouble(k), 0.0); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java index a9816fd7c..b9e9cf698 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java @@ -60,7 +60,7 @@ public class TestGraphNodes extends BaseDL4JTest { @Test public void testMergeNode() { Nd4j.getRandom().setSeed(12345); - GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType()); + GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1); INDArray first = Nd4j.linspace(0, 11, 12, Nd4j.dataType()).reshape(3, 4); INDArray second = Nd4j.linspace(0, 17, 18, Nd4j.dataType()).reshape(3, 6).addi(100); @@ -82,7 +82,7 @@ public class TestGraphNodes extends BaseDL4JTest { public void testMergeNodeRNN() { Nd4j.getRandom().setSeed(12345); - GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType()); + GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1); INDArray first = Nd4j.linspace(0, 59, 60, Nd4j.dataType()).reshape(3, 4, 5); INDArray second = Nd4j.linspace(0, 89, 90, Nd4j.dataType()).reshape(3, 6, 5).addi(100); @@ -103,7 +103,7 @@ public class TestGraphNodes extends BaseDL4JTest { @Test public void testCnnDepthMerge() { Nd4j.getRandom().setSeed(12345); - GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType()); + GraphVertex mergeNode = new MergeVertex(null, "", -1, Nd4j.dataType(), 1); INDArray first = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2); INDArray second = Nd4j.linspace(0, 3, 4, Nd4j.dataType()).reshape(1, 1, 2, 2).addi(10); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java new file mode 100644 index 000000000..c91bb8e56 --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java @@ -0,0 +1,974 @@ +/* ****************************************************************************** + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.nn.layers.convolution; + +import lombok.*; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.TestUtils; +import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.CnnLossLayer; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; +import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +@RunWith(Parameterized.class) +public class ConvDataFormatTests extends BaseDL4JTest { + + private final DataType dataType; + + public ConvDataFormatTests(DataType dataType){ + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{0}") + public static Object[] params(){ + return new DataType[]{DataType.FLOAT, DataType.DOUBLE}; + } + + @Test + public void testConv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getConv2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getConv2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getConv2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getConv2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSubsampling2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testDepthwiseConv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSeparableConv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testDeconv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testLRN() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getLrnLayer(CNN2DFormat.NCHW, true, cm)) + .net2(getLrnLayer(CNN2DFormat.NCHW, false, cm)) + .net3(getLrnLayer(CNN2DFormat.NHWC, true, cm)) + .net4(getLrnLayer(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testZeroPaddingLayer(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getZeroPaddingNet(CNN2DFormat.NCHW, true)) + .net2(getZeroPaddingNet(CNN2DFormat.NCHW, false)) + .net3(getZeroPaddingNet(CNN2DFormat.NHWC, true)) + .net4(getZeroPaddingNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testCropping2DLayer(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getCropping2dNet(CNN2DFormat.NCHW, true)) + .net2(getCropping2dNet(CNN2DFormat.NCHW, false)) + .net3(getCropping2dNet(CNN2DFormat.NHWC, true)) + .net4(getCropping2dNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testUpsampling2d(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getUpsamplingNet(CNN2DFormat.NCHW, true)) + .net2(getUpsamplingNet(CNN2DFormat.NCHW, false)) + .net3(getUpsamplingNet(CNN2DFormat.NHWC, true)) + .net4(getUpsamplingNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testBatchNormNet(){ + try { + for(boolean useLogStd : new boolean[]{true, false}) { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? "logstd" : "std"); + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true)) + .net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false)) + .net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true)) + .net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testCnnLossLayer() { + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3); + labelsNHWC = labelsNHWC.reshape(2,6,6,3); + INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup(); + + + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getCnnLossNet(CNN2DFormat.NCHW, true, ConvolutionMode.Same)) + .net2(getCnnLossNet(CNN2DFormat.NCHW, false, ConvolutionMode.Same)) + .net3(getCnnLossNet(CNN2DFormat.NHWC, true, ConvolutionMode.Same)) + .net4(getCnnLossNet(CNN2DFormat.NHWC, false, ConvolutionMode.Same)) + .inNCHW(inNCHW) + .labelsNCHW(labelsNCHW) + .labelsNHWC(labelsNHWC) + .testLayerIdx(1) + .nhwcOutput(true) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSpaceToDepthNet(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true)) + .net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false)) + .net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true)) + .net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSpaceToBatchNet(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16); + INDArray labels = TestUtils.randomOneHot(8, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true)) + .net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false)) + .net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true)) + .net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testLocallyConnected() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm)) + .net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm)) + .net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm)) + .net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + + @Test + public void testGlobalPooling() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (PoolingType pt : PoolingType.values()) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers (" + pt + ")" : "No helpers (" + pt + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true)) + .net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false)) + .net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true)) + .net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new SubsamplingLayer.Builder() + .kernelSize(2, 2) + .stride(1, 1) + .dataFormat(format) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new SubsamplingLayer.Builder() + .kernelSize(2, 2) + .stride(1, 1) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new SeparableConvolution2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new SeparableConvolution2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new DepthwiseConvolution2D.Builder() + .depthMultiplier(2) + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new DepthwiseConvolution2D.Builder() + .depthMultiplier(2) + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new LocalResponseNormalization.Builder() + .dataFormat(format) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new LocalResponseNormalization.Builder() + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2) + 
.dataFormat(format).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(), + format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new Cropping2D.Builder(2,2) + .dataFormat(format).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new Cropping2D.Builder(2,2) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new Upsampling2D.Builder(2) + .dataFormat(format).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new Upsampling2D.Builder(2) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) + .activation(Activation.TANH) + .kernelSize(2,2) + .stride(2,2) + .build(), format, cm, null); + } else { + return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) + .activation(Activation.TANH) + .kernelSize(2,2) + .stride(2,2) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new BatchNormalization.Builder() + .useLogStd(logStdev) + .dataFormat(format) + .helperAllowFallback(false) + .nOut(3).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new BatchNormalization.Builder() + .useLogStd(logStdev) + .helperAllowFallback(false) + .nOut(3).build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new SpaceToDepthLayer.Builder() + .blocks(2) + .dataFormat(format) + .build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new SpaceToDepthLayer.Builder() + .blocks(2) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new SpaceToBatchLayer.Builder() + .blocks(2, 2) + .dataFormat(format) + .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); + } else { + return getNetWithLayer(new SpaceToBatchLayer.Builder() + .blocks(2, 2) + .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); + } + } + + private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new LocallyConnected2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .build(), format, cm, null); + } else { + return getNetWithLayer(new LocallyConnected2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) { + NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + .dataType(this.dataType) + .seed(12345) + 
.convolutionMode(cm) + .list() + .layer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build()) + .layer(layer) + .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build()) + .setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format)); + + if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){ + //Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened + //DL4J's flattening behaviour matches Keras (hence TF) for import compatibility + builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor())); + } + + MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); + net.init(); + return net; + } + + private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) + .poolingDimensions(format == CNN2DFormat.NCHW ? new int[]{2,3} : new int[]{1,2}) + .build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){ + NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + .seed(12345) + .convolutionMode(cm) + .list() + .layer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build()); + if(setOnLayerAlso){ + builder.layer(new CnnLossLayer.Builder().format(format).activation(Activation.SOFTMAX).build()); + } else { + builder.layer(new CnnLossLayer.Builder().activation(Activation.SOFTMAX).build()); + } + + builder.setInputType(InputType.convolutional(12, 12, 3, format)); + + MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); + net.init(); + return net; + } + + @AllArgsConstructor + @Data + @NoArgsConstructor + @Builder + private static class TestCase { + private String msg; + private MultiLayerNetwork net1; + private MultiLayerNetwork net2; + private MultiLayerNetwork net3; + private MultiLayerNetwork net4; + private INDArray inNCHW; + private INDArray labelsNCHW; + private INDArray labelsNHWC; + private int testLayerIdx; + private boolean nhwcOutput; + } + + public static void testHelper(TestCase tc) { + + tc.net2.params().assign(tc.net1.params()); + tc.net3.params().assign(tc.net1.params()); + tc.net4.params().assign(tc.net1.params()); + + //Test forward pass: + INDArray inNCHW = tc.inNCHW; + INDArray inNHWC = tc.inNCHW.permute(0, 2, 3, 1).dup(); + + INDArray l0_1 = tc.net1.feedForward(inNCHW).get(tc.testLayerIdx + 1); + INDArray l0_2 = tc.net2.feedForward(inNCHW).get(tc.testLayerIdx + 1); + INDArray l0_3 = tc.net3.feedForward(inNHWC).get(tc.testLayerIdx + 1); + INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1); + + assertEquals(tc.msg, l0_1, l0_2); + if(l0_1.rank() == 4) { + assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2)); + assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2)); + } else { + assertEquals(tc.msg, l0_1, l0_3); + assertEquals(tc.msg, l0_1, l0_4); + } + + + INDArray out1 = tc.net1.output(inNCHW); + INDArray out2 = tc.net2.output(inNCHW); + INDArray out3 = tc.net3.output(inNHWC); + 
INDArray out4 = tc.net4.output(inNHWC); + + assertEquals(tc.msg, out1, out2); + if(!tc.nhwcOutput) { + assertEquals(tc.msg, out1, out3); + assertEquals(tc.msg, out1, out4); + } else { + assertEquals(tc.msg, out1, out3.permute(0,3,1,2)); //NHWC to NCHW + assertEquals(tc.msg, out1, out4.permute(0,3,1,2)); + } + + //Test backprop + Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCHW, tc.labelsNCHW, null, null); + Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCHW, tc.labelsNCHW, null, null); + Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNHWC, tc.labelsNHWC, null, null); + Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNHWC, tc.labelsNHWC, null, null); + + //Input gradients + assertEquals(tc.msg, p1.getSecond(), p2.getSecond()); + assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0,3,1,2)); //Input gradients for NHWC input are also in NHWC format + assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0,3,1,2)); + + List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst()); + List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst()); + List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst()); + assertEquals(tc.msg + " " + diff12, 0, diff12.size()); + assertEquals(tc.msg + " " + diff13, 0, diff13.size()); + assertEquals(tc.msg + " " + diff14, 0, diff14.size()); + + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable()); + + tc.net1.fit(inNCHW, tc.labelsNCHW); + tc.net2.fit(inNCHW, tc.labelsNCHW); + tc.net3.fit(inNHWC, tc.labelsNHWC); + tc.net4.fit(inNHWC, tc.labelsNHWC); + + assertEquals(tc.msg, tc.net1.params(), tc.net2.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net3.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net4.params()); + + //Test serialization + MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); + MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2); + MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3); + MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4); + + out1 = tc.net1.output(inNCHW); + assertEquals(tc.msg, out1, net1a.output(inNCHW)); + assertEquals(tc.msg, out1, net2a.output(inNCHW)); + if(!tc.nhwcOutput) { + assertEquals(tc.msg, out1, net3a.output(inNHWC)); + assertEquals(tc.msg, out1, net4a.output(inNHWC)); + } else { + assertEquals(tc.msg, out1, net3a.output(inNHWC).permute(0,3,1,2)); //NHWC to NCHW + assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2)); + } + + } + + private static List<String> differentGrads(Gradient g1, Gradient g2){ + List<String> differs = new ArrayList<>(); + Map<String, INDArray> m1 = g1.gradientForVariable(); + Map<String, INDArray> m2 = g2.gradientForVariable(); + for(String s : m1.keySet()){ + INDArray a1 = m1.get(s); + INDArray a2 = m2.get(s); + if(!a1.equals(a2)){ + differs.add(s); + } + } + return differs; + } + + + //Converts NHWC to NCHW activations + @EqualsAndHashCode + private static class NHWCToNCHWPreprocessor implements InputPreProcessor { + + @Override + public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2)); + } + + @Override + public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1)); + } + + @Override + public 
InputPreProcessor clone() { + return this; + } + + @Override + public InputType getOutputType(InputType inputType) { + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; + return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW); + } + + @Override + public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { + return null; + } + } +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java index 037640ebb..0f5b5d6d1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.layers.convolution; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; @@ -45,6 +46,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 15/11/2016. */ +@Slf4j public class TestConvolutionModes extends BaseDL4JTest { @Test @@ -106,12 +108,12 @@ public class TestConvolutionModes extends BaseDL4JTest { } } catch (DL4JException e) { if (inSize == 9 || cm != ConvolutionMode.Strict) { - e.printStackTrace(); + log.error("",e); fail("Unexpected exception"); } continue; //Expected exception } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Unexpected exception"); } @@ -184,12 +186,12 @@ public class TestConvolutionModes extends BaseDL4JTest { } } catch (DL4JException e) { if (inSize == 9 || cm != ConvolutionMode.Strict) { - e.printStackTrace(); + log.error("",e); fail("Unexpected exception"); } continue; //Expected exception } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail("Unexpected exception"); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java index f7a4c087f..489687679 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java @@ -24,10 +24,7 @@ import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator; import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.WorkspaceMode; +import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM; import org.deeplearning4j.nn.conf.layers.GravesLSTM; @@ -45,6 +42,8 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.TimeSeriesUtils; 
import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -61,12 +60,22 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import static org.deeplearning4j.nn.conf.RNNFormat.NCW; import static org.junit.Assert.assertEquals; @Slf4j +@RunWith(Parameterized.class) public class BidirectionalTest extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + public BidirectionalTest(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void compareImplementations(){ for(WorkspaceMode wsm : WorkspaceMode.values()) { @@ -82,9 +91,9 @@ public class BidirectionalTest extends BaseDL4JTest { .inferenceWorkspaceMode(wsm) .updater(new Adam()) .list() - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) - .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) + .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) + .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) + .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat) .nIn(10).nOut(10).build()) .build(); @@ -95,9 +104,9 @@ public class BidirectionalTest extends BaseDL4JTest { .inferenceWorkspaceMode(wsm) .updater(new Adam()) .list() - .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build()) - .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build()) - .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) + .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) + .layer(new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) + .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat) .nIn(10).nOut(10).build()) .build(); @@ -116,15 +125,24 @@ public class BidirectionalTest extends BaseDL4JTest { net2.setParams(net1.params()); //Assuming exact same layout here... 
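+ //RNNFormat of the activations: NCW is [minibatch, size, timeSeriesLength], NWC is [minibatch, timeSeriesLength, size] - the shape branches below pick the layout matching the parameterized rnnDataFormat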
- INDArray in = Nd4j.rand(new int[]{3, 10, 5}); + INDArray in; + if (rnnDataFormat == NCW){ + in = Nd4j.rand(new int[]{3, 10, 5}); + }else{ + in = Nd4j.rand(new int[]{3, 5, 10}); + } INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); assertEquals(out1, out2); - INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); - + INDArray labels; + if (rnnDataFormat == NCW){ + labels = Nd4j.rand(new int[]{3, 10, 5}); + }else{ + labels = Nd4j.rand(new int[]{3, 5, 10}); + } net1.setInput(in); net1.setLabels(labels); @@ -276,17 +294,22 @@ public class BidirectionalTest extends BaseDL4JTest { .inferenceWorkspaceMode(wsm) .updater(new Adam()) .list() - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) - .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build())) + .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) + .layer(new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) .layer(new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) - .nIn(10).nOut(10).build()) + .nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); - INDArray in = Nd4j.rand(new int[]{3, 10, 5}); - INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); + INDArray in; + INDArray labels; + + long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 5} : new long[]{3, 5, 10}; + + in = Nd4j.rand(inshape); + labels = Nd4j.rand(inshape); net1.fit(in, labels); @@ -300,8 +323,8 @@ public class BidirectionalTest extends BaseDL4JTest { MultiLayerNetwork net2 = ModelSerializer.restoreMultiLayerNetwork(new ByteArrayInputStream(bytes), true); - in = Nd4j.rand(new int[]{3, 10, 5}); - labels = Nd4j.rand(new int[]{3, 10, 5}); + in = Nd4j.rand(inshape); + labels = Nd4j.rand(inshape); INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); @@ -338,18 +361,18 @@ public class BidirectionalTest extends BaseDL4JTest { .updater(new Adam()) .graphBuilder() .addInputs("in") - .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "in") - .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).build()), "0") - .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE) + .layer("0", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in") + .layer("1", new Bidirectional(Bidirectional.Mode.ADD, new GravesLSTM.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "0") + .layer("2", new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).dataFormat(rnnDataFormat) .nIn(10).nOut(10).build(), "1") .setOutputs("2") .build(); ComputationGraph net1 = new ComputationGraph(conf1); net1.init(); - - INDArray in = Nd4j.rand(new int[]{3, 10, 5}); - INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); + long[] inshape = (rnnDataFormat == NCW)? 
new long[]{3, 10, 5}: new long[]{3, 5, 10}; + INDArray in = Nd4j.rand(inshape); + INDArray labels = Nd4j.rand(inshape); net1.fit(new DataSet(in, labels)); @@ -363,8 +386,8 @@ public class BidirectionalTest extends BaseDL4JTest { ComputationGraph net2 = ModelSerializer.restoreComputationGraph(new ByteArrayInputStream(bytes), true); - in = Nd4j.rand(new int[]{3, 10, 5}); - labels = Nd4j.rand(new int[]{3, 10, 5}); + in = Nd4j.rand(inshape); + labels = Nd4j.rand(inshape); INDArray out1 = net1.outputSingle(in); INDArray out2 = net2.outputSingle(in); @@ -394,8 +417,8 @@ public class BidirectionalTest extends BaseDL4JTest { Bidirectional.Mode[] modes = new Bidirectional.Mode[]{Bidirectional.Mode.CONCAT, Bidirectional.Mode.ADD, Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL}; - - INDArray in = Nd4j.rand(new int[]{3, 10, 6}); + long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10}; + INDArray in = Nd4j.rand(inshape); for (Bidirectional.Mode m : modes) { MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() @@ -406,7 +429,7 @@ public class BidirectionalTest extends BaseDL4JTest { .inferenceWorkspaceMode(wsm) .updater(new Adam()) .list() - .layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build())) + .layer(new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build())) .build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); @@ -418,7 +441,7 @@ public class BidirectionalTest extends BaseDL4JTest { .weightInit(WeightInit.XAVIER) .updater(new Adam()) .list() - .layer(new SimpleRnn.Builder().nIn(10).nOut(10).build()) + .layer(new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2.clone()); @@ -434,11 +457,10 @@ public class BidirectionalTest extends BaseDL4JTest { net3.setParam("0_RW", net1.getParam("0_bRW")); net3.setParam("0_b", net1.getParam("0_bb")); - INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); - + INDArray inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat); INDArray out1 = net1.output(in); INDArray out2 = net2.output(in); - INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); + INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.output(inReverse), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat); INDArray outExp; switch (m) { @@ -452,7 +474,7 @@ public class BidirectionalTest extends BaseDL4JTest { outExp = out2.add(out3).muli(0.5); break; case CONCAT: - outExp = Nd4j.concat(1, out2, out3); + outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3); break; default: throw new RuntimeException(); @@ -464,25 +486,25 @@ public class BidirectionalTest extends BaseDL4JTest { //Check gradients: if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) { - INDArray eps = Nd4j.rand(new int[]{3, 10, 6}); + INDArray eps = Nd4j.rand(inshape); INDArray eps1; if (m == Bidirectional.Mode.CONCAT) { - eps1 = Nd4j.concat(1, eps, eps); + eps1 = Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps); } else { eps1 = eps; } net1.setInput(in); net2.setInput(in); - net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)); + net3.setInput(TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat)); 
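+ //net3 consumes the time-reversed input; its outputs and gradients are reversed again (in the same rnnDataFormat) so they can be compared against the backward half of the bidirectional net1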
net1.feedForward(true, false); net2.feedForward(true, false); net3.feedForward(true, false); Pair<Gradient, INDArray> p1 = net1.backpropGradient(eps1, LayerWorkspaceMgr.noWorkspaces()); Pair<Gradient, INDArray> p2 = net2.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); - Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT), LayerWorkspaceMgr.noWorkspaces()); + Pair<Gradient, INDArray> p3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT, rnnDataFormat), LayerWorkspaceMgr.noWorkspaces()); Gradient g1 = p1.getFirst(); Gradient g2 = p2.getFirst(); Gradient g3 = p3.getFirst(); @@ -520,7 +542,9 @@ public class BidirectionalTest extends BaseDL4JTest { Bidirectional.Mode.AVERAGE, Bidirectional.Mode.MUL}; - INDArray in = Nd4j.rand(new int[]{3, 10, 6}); + long[] inshape = rnnDataFormat == NCW ? new long[]{3, 10, 6} : new long[]{3, 6, 10}; + INDArray in = Nd4j.rand(inshape); + for (Bidirectional.Mode m : modes) { ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() @@ -532,7 +556,7 @@ public class BidirectionalTest extends BaseDL4JTest { .updater(new Adam()) .graphBuilder() .addInputs("in") - .layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).build()), "in") + .layer("0", new Bidirectional(m, new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build()), "in") .setOutputs("0") .build(); @@ -546,7 +570,7 @@ public class BidirectionalTest extends BaseDL4JTest { .updater(new Adam()) .graphBuilder() .addInputs("in") - .layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).build(), "in") + .layer("0", new SimpleRnn.Builder().nIn(10).nOut(10).dataFormat(rnnDataFormat).build(), "in") .setOutputs("0") .build(); @@ -566,9 +590,20 @@ public class BidirectionalTest extends BaseDL4JTest { INDArray out1 = net1.outputSingle(in); INDArray out2 = net2.outputSingle(in); - INDArray out3 = TimeSeriesUtils.reverseTimeSeries(net3.outputSingle( - TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)), - LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); + INDArray out3; + INDArray inReverse; + if (rnnDataFormat == RNNFormat.NWC){ + inReverse = TimeSeriesUtils.reverseTimeSeries(in.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1); + out3 = net3.outputSingle(inReverse); + out3 = TimeSeriesUtils.reverseTimeSeries(out3.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT).permute(0, 2, 1); + + } + else{ + inReverse = TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); + out3 = net3.outputSingle(inReverse); + out3 = TimeSeriesUtils.reverseTimeSeries(out3, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT); + + } INDArray outExp; switch (m) { @@ -582,7 +617,9 @@ public class BidirectionalTest extends BaseDL4JTest { outExp = out2.add(out3).muli(0.5); break; case CONCAT: - outExp = Nd4j.concat(1, out2, out3); + System.out.println(out2.shapeInfoToString()); + System.out.println(out3.shapeInfoToString()); + outExp = Nd4j.concat((rnnDataFormat == NCW)?1:2, out2, out3); break; default: throw new RuntimeException(); @@ -594,22 +631,26 @@ public class BidirectionalTest extends BaseDL4JTest { //Check gradients: if (m == Bidirectional.Mode.ADD || m == Bidirectional.Mode.CONCAT) { - INDArray eps = Nd4j.rand(new int[]{3, 10, 6}); + INDArray eps = Nd4j.rand(inshape); INDArray eps1; if (m == Bidirectional.Mode.CONCAT) { - eps1 = Nd4j.concat(1, eps, eps); + eps1 = 
Nd4j.concat((rnnDataFormat == NCW)?1:2, eps, eps); } else { eps1 = eps; } + INDArray epsReversed = (rnnDataFormat == NCW)? + TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT): + TimeSeriesUtils.reverseTimeSeries(eps.permute(0, 2, 1), LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT) + .permute(0, 2, 1); net1.outputSingle(true, false, in); net2.outputSingle(true, false, in); - net3.outputSingle(true, false, TimeSeriesUtils.reverseTimeSeries(in, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)); + net3.outputSingle(true, false, inReverse); Gradient g1 = net1.backpropGradient(eps1); Gradient g2 = net2.backpropGradient(eps); - Gradient g3 = net3.backpropGradient(TimeSeriesUtils.reverseTimeSeries(eps, LayerWorkspaceMgr.noWorkspaces(), ArrayType.INPUT)); + Gradient g3 = net3.backpropGradient(epsReversed); for (boolean updates : new boolean[]{false, true}) { if (updates) { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index 751b6f6bf..441267e86 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -31,6 +32,8 @@ import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.api.ndarray.INDArray; @@ -42,10 +45,18 @@ import org.nd4j.linalg.primitives.Pair; import static org.junit.Assert.*; - +@RunWith(Parameterized.class) public class GravesBidirectionalLSTMTest extends BaseDL4JTest { private double score = 0.0; + private RNNFormat rnnDataFormat; + public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void testBidirectionalLSTMGravesForwardBasic() { //Very basic test of forward prop. of LSTM layer with a time series. 
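// [Editor's note: illustrative sketch, not part of this patch.] The parameterized
// assertions in the following hunk all encode a single rule: activations keep the
// configured layout, i.e. [minibatch, nOut, timeSteps] under NCW versus
// [minibatch, timeSteps, nOut] under NWC. A hypothetical helper stating that rule once:
static long[] expectedRnnActivationShape(RNNFormat format, long mb, long nOut, long t) {
    return format == RNNFormat.NCW
            ? new long[]{mb, nOut, t}    //size-before-time: [minibatch, layerSize, timeSteps]
            : new long[]{mb, t, nOut};   //time-before-size: [minibatch, timeSteps, layerSize]
}
// Usage (hypothetical): assertArrayEquals(expectedRnnActivationShape(rnnDataFormat, 10, nHiddenUnits, 15), activations4.shape());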
@@ -55,7 +66,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(nHiddenUnits).activation(Activation.TANH).build()) + .nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build()) .build(); val numParams = conf.getLayer().initializer().numParams(conf); @@ -65,22 +76,41 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; + if (rnnDataFormat == RNNFormat.NCW){ + final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1); + final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1}); - final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, nIn, 1); - final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations1.shape(), new long[] {1, nHiddenUnits, 1}); + final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1); + final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1}); - final INDArray dataMultiExampleLength1 = Nd4j.ones(10, nIn, 1); - final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations2.shape(), new long[] {10, nHiddenUnits, 1}); + final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12); + final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12}); - final INDArray dataSingleExampleLength12 = Nd4j.ones(1, nIn, 12); - final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations3.shape(), new long[] {1, nHiddenUnits, 12}); + final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15); + final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15}); + } + else{ + final INDArray dataSingleExampleTimeLength1 = Nd4j.ones(1, 1, nIn); + final INDArray activations1 = layer.activate(dataSingleExampleTimeLength1, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations1.shape(), new long[] {1, 1, nHiddenUnits}); + + final INDArray dataMultiExampleLength1 = Nd4j.ones(10, 1, nIn); + final INDArray activations2 = layer.activate(dataMultiExampleLength1, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations2.shape(), new long[] {10, 1, nHiddenUnits}); + + final INDArray dataSingleExampleLength12 = Nd4j.ones(1, 12, nIn); + final INDArray activations3 = layer.activate(dataSingleExampleLength12, false, LayerWorkspaceMgr.noWorkspaces()); + assertArrayEquals(activations3.shape(), new long[] {1, 12, nHiddenUnits}); + + final INDArray dataMultiExampleLength15 = Nd4j.ones(10, 15, nIn); + final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces()); + 
assertArrayEquals(activations4.shape(), new long[] {10, 15, nHiddenUnits}); + } - final INDArray dataMultiExampleLength15 = Nd4j.ones(10, nIn, 15); - final INDArray activations4 = layer.activate(dataMultiExampleLength15, false, LayerWorkspaceMgr.noWorkspaces()); - assertArrayEquals(activations4.shape(), new long[] {10, nHiddenUnits, 15}); } @Test @@ -94,14 +124,15 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { testGravesBackwardBasicHelper(13, 3, 17, 1, 1); //Edge case: both miniBatchSize = 1 and timeSeriesLength = 1 } - private static void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, + private void testGravesBackwardBasicHelper(int nIn, int nOut, int lstmNHiddenUnits, int miniBatchSize, int timeSeriesLength) { - INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength); + INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength): + Nd4j.ones(miniBatchSize, timeSeriesLength, nIn); NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(lstmNHiddenUnits) + .nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat) .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .build(); @@ -114,7 +145,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); assertNotNull(lstm.input()); - INDArray epsilon = Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength); + INDArray epsilon = (rnnDataFormat == RNNFormat.NCW)? Nd4j.ones(miniBatchSize, lstmNHiddenUnits, timeSeriesLength): + Nd4j.ones(miniBatchSize, timeSeriesLength, lstmNHiddenUnits); Pair<Gradient, INDArray> out = lstm.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); Gradient outGradient = out.getFirst(); @@ -147,7 +179,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3}); assertNotNull(nextEpsilon); - assertArrayEquals(nextEpsilon.shape(), new long[] {miniBatchSize, nIn, timeSeriesLength}); + if (rnnDataFormat == RNNFormat.NCW) { + assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, nIn, timeSeriesLength}); + }else{ + assertArrayEquals(nextEpsilon.shape(), new long[]{miniBatchSize, timeSeriesLength, nIn}); + } //Check update: for (String s : outGradient.gradientForVariable().keySet()) { @@ -226,7 +262,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) - .nOut(layerSize) + .nOut(layerSize).dataFormat(rnnDataFormat) .dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build()) .build(); @@ -237,7 +273,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .instantiate(confBidirectional, null, 0, params, true, params.dataType()); - final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); + final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}): + Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn}); final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()); @@ -265,13 +302,13 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final 
NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() - .nIn(nIn).nOut(layerSize) + .nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat) .dist(new UniformDistribution(-0.1, 0.1)) .activation(Activation.TANH).updater(new NoOp()).build()) .build(); final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder() - .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize) + .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat) .weightInit(WeightInit.ZERO).activation(Activation.TANH).build()) .build(); @@ -290,9 +327,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards))); - final INDArray sig = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); + final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}): + Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn}); final INDArray sigb = sig.dup(); - reverseColumnsInPlace(sigb.slice(0)); + + if (rnnDataFormat == RNNFormat.NCW) { + reverseColumnsInPlace(sigb.slice(0)); + } + else{ + reverseColumnsInPlace(sigb.slice(0).permute(1, 0)); + } final INDArray recurrentWeightsF = bidirectionalLSTM .getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS); @@ -345,10 +389,14 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f); - final INDArray randSig = Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}); - final INDArray randSigBackwards = randSig.dup(); - reverseColumnsInPlace(randSigBackwards.slice(0)); - + final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}): + Nd4j.rand(new int[] {1, timeSeriesLength, layerSize}); + INDArray randSigBackwards = randSig.dup(); + if (rnnDataFormat == RNNFormat.NCW){ + reverseColumnsInPlace(randSigBackwards.slice(0)); + }else{ + reverseColumnsInPlace(randSigBackwards.slice(0).permute(1, 0)); + } final Pair<Gradient, INDArray> backprop1 = forwardsLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces()); final Pair<Gradient, INDArray> backprop2 = bidirectionalLSTM.backpropGradient(randSig, LayerWorkspaceMgr.noWorkspaces()); @@ -399,10 +447,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final INDArray activation3 = bidirectionalLSTM.activate(sigb, false, LayerWorkspaceMgr.noWorkspaces()).slice(0); final INDArray activation3Reverse = activation3.dup(); - reverseColumnsInPlace(activation3Reverse); + if (rnnDataFormat == RNNFormat.NCW){ + reverseColumnsInPlace(activation3Reverse); + } + else{ + reverseColumnsInPlace(activation3Reverse.permute(1, 0)); + } - assertEquals(activation3Reverse, activation1); assertArrayEquals(activation3Reverse.shape(), activation1.shape()); + assertEquals(activation3Reverse, activation1); + //test backprop now final INDArray refBackGradientReccurrent = @@ -434,7 +488,12 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final INDArray refEpsilon = backprop1.getSecond().dup(); final INDArray backEpsilon = backprop3.getSecond().dup(); - reverseColumnsInPlace(refEpsilon.slice(0)); + if (rnnDataFormat == RNNFormat.NCW) { + reverseColumnsInPlace(refEpsilon.slice(0)); + } + else{ + reverseColumnsInPlace(refEpsilon.slice(0).permute(1, 0)); + } 
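// [Editor's note: illustrative sketch, not part of this patch.] The repeated if/else above
// reverses the time dimension of a single example: reverseColumnsInPlace flips the last
// (column) axis, and slice(0) is [layerSize, timeSeriesLength] under NCW but
// [timeSeriesLength, layerSize] under NWC, hence the transposed view first. Assuming this
// test's reverseColumnsInPlace helper, the pattern factors out as (name hypothetical):
static void reverseTimeInPlace(INDArray singleExample, RNNFormat format) {
    if (format == RNNFormat.NCW) {
        reverseColumnsInPlace(singleExample);               //columns already index time steps
    } else {
        reverseColumnsInPlace(singleExample.permute(1, 0)); //permute returns a view, so the original is modified in place
    }
}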
assertArrayEquals(backEpsilon.dup().data().asDouble(), refEpsilon.dup().data().asDouble(), 1e-6); } @@ -477,10 +536,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .seed(12345).list() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() - .gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2) + .gateActivationFunction(gateAfn).activation(Activation.TANH).nIn(2).nOut(2).dataFormat(rnnDataFormat) .build()) .layer(1, new org.deeplearning4j.nn.conf.layers.RnnOutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2) + .lossFunction(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).dataFormat(rnnDataFormat) .activation(Activation.TANH).build()) .build(); @@ -492,7 +551,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { INDArray in = Nd4j.rand(new int[] {3, 2, 5}); INDArray labels = Nd4j.rand(new int[] {3, 2, 5}); - + if (rnnDataFormat == RNNFormat.NWC){ + in = in.permute(0, 2, 1); + labels = labels.permute(0, 2, 1); + } net.fit(in, labels); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java index d13545694..7ddc31220 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java @@ -21,11 +21,14 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.TrainingListener; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -36,9 +39,17 @@ import java.util.Collections; import static org.junit.Assert.assertEquals; - +@RunWith(Parameterized.class) public class MaskZeroLayerTest extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + public MaskZeroLayerTest(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void activate() { @@ -57,7 +68,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest { .activation(Activation.IDENTITY) .gateActivationFunction(Activation.IDENTITY) .nIn(2) - .nOut(1) + .nOut(1).dataFormat(rnnDataFormat) .build(); NeuralNetConfiguration conf = new NeuralNetConfiguration(); conf.setLayer(underlying); @@ -72,20 +83,25 @@ public class MaskZeroLayerTest extends BaseDL4JTest { MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue); INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3}); + if (rnnDataFormat == RNNFormat.NWC){ + input = input.permute(0, 2, 1); + } //WHEN INDArray out = l.activate(input, true, LayerWorkspaceMgr.noWorkspaces()); - + if (rnnDataFormat == RNNFormat.NWC){ + out = out.permute(0, 2,1); + } //THEN output should only 
be incremented for the non-zero timesteps INDArray firstExampleOutput = out.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.all()); INDArray secondExampleOutput = out.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()); - assertEquals(firstExampleOutput.getDouble(0), 0.0, 1e-6); - assertEquals(firstExampleOutput.getDouble(1), 1.0, 1e-6); - assertEquals(firstExampleOutput.getDouble(2), 2.0, 1e-6); + assertEquals(0.0, firstExampleOutput.getDouble(0), 1e-6); + assertEquals(1.0, firstExampleOutput.getDouble(1), 1e-6); + assertEquals(2.0, firstExampleOutput.getDouble(2), 1e-6); - assertEquals(secondExampleOutput.getDouble(0), 0.0, 1e-6); - assertEquals(secondExampleOutput.getDouble(1), 0.0, 1e-6); - assertEquals(secondExampleOutput.getDouble(2), 1.0, 1e-6); + assertEquals(0.0, secondExampleOutput.getDouble(0), 1e-6); + assertEquals(0.0, secondExampleOutput.getDouble(1), 1e-6); + assertEquals(1.0, secondExampleOutput.getDouble(2), 1e-6); } @@ -94,7 +110,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .list() .layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder() - .setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).build()).build()) + .setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java new file mode 100644 index 000000000..43dd93f56 --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java @@ -0,0 +1,394 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.layers.recurrent; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.TestUtils; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM; +import org.deeplearning4j.nn.conf.layers.GravesLSTM; +import org.deeplearning4j.nn.conf.layers.LSTM; +import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; +import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; +import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +@RunWith(Parameterized.class) +@AllArgsConstructor +public class RnnDataFormatTests extends BaseDL4JTest { + + private boolean helpers; + private boolean lastTimeStep; + private boolean maskZeros; + + @Parameterized.Parameters(name = "helpers={0},lastTimeStep={1},maskZero={2}") + public static List<Object[]> params(){ + List<Object[]> ret = new ArrayList<>(); + for (boolean helpers: new boolean[]{true, false}) + for (boolean lastTimeStep: new boolean[]{true, false}) + for (boolean maskZero: new boolean[]{true, false}) + ret.add(new Object[]{helpers, lastTimeStep, maskZero}); + return ret; + } + + + @Test + public void testSimpleRnn() { + try { + + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12); + + INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSimpleRnnNet(RNNFormat.NCW, true, lastTimeStep, maskZeros)) + .net2(getSimpleRnnNet(RNNFormat.NCW, false, lastTimeStep, maskZeros)) + .net3(getSimpleRnnNet(RNNFormat.NWC, true, lastTimeStep, maskZeros)) + .net4(getSimpleRnnNet(RNNFormat.NWC, false, lastTimeStep, maskZeros)) + .inNCW(inNCW) + .labelsNCW((lastTimeStep)? 
labelsNWC: labelsNWC.permute(0, 2, 1)) + .labelsNWC(labelsNWC) + .testLayerIdx(1) + .build(); + + TestCase.testHelper(tc); + + + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testLSTM() { + try { + + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12); + + INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros)) + .net2(getLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros)) + .net3(getLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros)) + .net4(getLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros)) + .inNCW(inNCW) + .labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1)) + .labelsNWC(labelsNWC) + .testLayerIdx(1) + .build(); + + TestCase.testHelper(tc); + + + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + + @Test + public void testGraveLSTM() { + try { + + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12); + + INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getGravesLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros)) + .net2(getGravesLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros)) + .net3(getGravesLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros)) + .net4(getGravesLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros)) + .inNCW(inNCW) + .labelsNCW((lastTimeStep)? labelsNWC: labelsNWC.permute(0, 2, 1)) + .labelsNWC(labelsNWC) + .testLayerIdx(1) + .build(); + + TestCase.testHelper(tc); + + + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + + @Test + public void testGraveBiLSTM() { + try { + + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = "Helpers: " + helpers + ", lastTimeStep: " + lastTimeStep + ", maskZeros: " + maskZeros; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCW = Nd4j.rand(DataType.FLOAT, 2, 3, 12); + + INDArray labelsNWC = (lastTimeStep) ?TestUtils.randomOneHot(2, 10): TestUtils.randomOneHot(2 * 12, 10).reshape(2, 12, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getGravesBidirectionalLstmNet(RNNFormat.NCW, true, lastTimeStep, maskZeros)) + .net2(getGravesBidirectionalLstmNet(RNNFormat.NCW, false, lastTimeStep, maskZeros)) + .net3(getGravesBidirectionalLstmNet(RNNFormat.NWC, true, lastTimeStep, maskZeros)) + .net4(getGravesBidirectionalLstmNet(RNNFormat.NWC, false, lastTimeStep, maskZeros)) + .inNCW(inNCW) + .labelsNCW((lastTimeStep)? 
labelsNWC: labelsNWC.permute(0, 2, 1)) + .labelsNWC(labelsNWC) + .testLayerIdx(1) + .build(); + + TestCase.testHelper(tc); + + + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + + private MultiLayerNetwork getGravesBidirectionalLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) { + if (setOnLayerAlso) { + return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3) + .dataFormat(format).build(), format, lastTimeStep, maskZeros); + } else { + return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); + } + } + private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) { + if (setOnLayerAlso) { + return getNetWithLayer(new GravesLSTM.Builder().nOut(3) + .dataFormat(format).build(), format, lastTimeStep, maskZeros); + } else { + return getNetWithLayer(new GravesLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); + } + } + + private MultiLayerNetwork getLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) { + if (setOnLayerAlso) { + return getNetWithLayer(new LSTM.Builder().nOut(3) + .dataFormat(format).build(), format, lastTimeStep, maskZeros); + } else { + return getNetWithLayer(new LSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); + } + } + + private MultiLayerNetwork getSimpleRnnNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) { + if (setOnLayerAlso) { + return getNetWithLayer(new SimpleRnn.Builder().nOut(3) + .dataFormat(format).build(), format, lastTimeStep, maskZeros); + } else { + return getNetWithLayer(new SimpleRnn.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); + } + } + private MultiLayerNetwork getNetWithLayer(Layer layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) { + if (maskZeros){ + layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build(); + } + if(lastTimeStep){ + layer = new LastTimeStep(layer); + } + NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + .seed(12345) + .list() + .layer(new LSTM.Builder() + .nIn(3) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build()) + .layer(layer) + .layer( + (lastTimeStep)?new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build(): + new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).dataFormat(format).build() + ) + .setInputType(InputType.recurrent(3, 12, format)); + + MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); + net.init(); + return net; + } + + @AllArgsConstructor + @Data + @NoArgsConstructor + @Builder + private static class TestCase { + private String msg; + private MultiLayerNetwork net1; + private MultiLayerNetwork net2; + private MultiLayerNetwork net3; + private MultiLayerNetwork net4; + private INDArray inNCW; + private INDArray labelsNCW; + private INDArray labelsNWC; + private int testLayerIdx; + private boolean nwcOutput; + + public static void testHelper(TestCase tc) { + + tc.net2.params().assign(tc.net1.params()); + tc.net3.params().assign(tc.net1.params()); + tc.net4.params().assign(tc.net1.params()); + + INDArray inNCW = tc.inNCW; + INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup(); + + INDArray l0_1 = tc.net1.feedForward(inNCW).get(tc.testLayerIdx + 1); + INDArray l0_2 = tc.net2.feedForward(inNCW).get(tc.testLayerIdx + 1); + INDArray l0_3 = 
tc.net3.feedForward(inNWC).get(tc.testLayerIdx + 1); + INDArray l0_4 = tc.net4.feedForward(inNWC).get(tc.testLayerIdx + 1); + + boolean rank3Out = tc.labelsNCW.rank() == 3; + assertEquals(tc.msg, l0_1, l0_2); + if (rank3Out){ + assertEquals(tc.msg, l0_1, l0_3.permute(0, 2, 1)); + assertEquals(tc.msg, l0_1, l0_4.permute(0, 2, 1)); + } + else{ + assertEquals(tc.msg, l0_1, l0_3); + assertEquals(tc.msg, l0_1, l0_4); + } + INDArray out1 = tc.net1.output(inNCW); + INDArray out2 = tc.net2.output(inNCW); + INDArray out3 = tc.net3.output(inNWC); + INDArray out4 = tc.net4.output(inNWC); + + assertEquals(tc.msg, out1, out2); + if (rank3Out){ + assertEquals(tc.msg, out1, out3.permute(0, 2, 1)); //NWC to NCW + assertEquals(tc.msg, out1, out4.permute(0, 2, 1)); + } + else{ + assertEquals(tc.msg, out1, out3); //NWC to NCW + assertEquals(tc.msg, out1, out4); + } + + + //Test backprop + Pair<Gradient, INDArray> p1 = tc.net1.calculateGradients(inNCW, tc.labelsNCW, null, null); + Pair<Gradient, INDArray> p2 = tc.net2.calculateGradients(inNCW, tc.labelsNCW, null, null); + Pair<Gradient, INDArray> p3 = tc.net3.calculateGradients(inNWC, tc.labelsNWC, null, null); + Pair<Gradient, INDArray> p4 = tc.net4.calculateGradients(inNWC, tc.labelsNWC, null, null); + + //Input gradients + assertEquals(tc.msg, p1.getSecond(), p2.getSecond()); + + assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0, 2, 1)); //Input gradients for NWC input are also in NWC format + assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0, 2, 1)); + + + List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst()); + List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst()); + List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst()); + assertEquals(tc.msg + " " + diff12, 0, diff12.size()); + assertEquals(tc.msg + " " + diff13, 0, diff13.size()); + assertEquals(tc.msg + " " + diff14, 0, diff14.size()); + + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable()); + + tc.net1.fit(inNCW, tc.labelsNCW); + tc.net2.fit(inNCW, tc.labelsNCW); + tc.net3.fit(inNWC, tc.labelsNWC); + tc.net4.fit(inNWC, tc.labelsNWC); + + assertEquals(tc.msg, tc.net1.params(), tc.net2.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net3.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net4.params()); + + //Test serialization + MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); + MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2); + MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3); + MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4); + + out1 = tc.net1.output(inNCW); + assertEquals(tc.msg, out1, net1a.output(inNCW)); + assertEquals(tc.msg, out1, net2a.output(inNCW)); + + if (rank3Out) { + assertEquals(tc.msg, out1, net3a.output(inNWC).permute(0, 2, 1)); //NWC to NCW + assertEquals(tc.msg, out1, net4a.output(inNWC).permute(0, 2, 1)); + } + else{ + assertEquals(tc.msg, out1, net3a.output(inNWC)); //NWC to NCW + assertEquals(tc.msg, out1, net4a.output(inNWC)); + } + } + + } + private static List<String> differentGrads(Gradient g1, Gradient g2){ + List<String> differs = new ArrayList<>(); + Map<String, INDArray> m1 = g1.gradientForVariable(); + Map<String, INDArray> m2 = g2.gradientForVariable(); + for(String s : m1.keySet()){ + INDArray a1 = m1.get(s); + INDArray a2 = m2.get(s); + if(!a1.equals(a2)){ + differs.add(s); + } + } + return differs; + } +} diff --git 
a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java index 7dd965ffb..9f60d674d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java @@ -21,6 +21,7 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.LSTM; @@ -29,6 +30,8 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.graph.ComputationGraph; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -42,14 +45,25 @@ import static org.nd4j.linalg.activations.Activation.IDENTITY; import static org.nd4j.linalg.activations.Activation.TANH; import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE; + +@RunWith(Parameterized.class) public class TestLastTimeStepLayer extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + + public TestLastTimeStepLayer(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters(name="{0}") + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void testLastTimeStepVertex() { ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") .addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder() - .nIn(5).nOut(6).build()), "in") + .nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in") .setOutputs("lastTS") .build(); @@ -59,9 +73,22 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { //First: test without input mask array Nd4j.getRandom().setSeed(12345); Layer l = graph.getLayer("lastTS"); - INDArray in = Nd4j.rand(new int[]{3, 5, 6}); + INDArray in; + if (rnnDataFormat == RNNFormat.NCW){ + in = Nd4j.rand(3, 5, 6); + } + else{ + in = Nd4j.rand(3, 6, 5); + } INDArray outUnderlying = ((LastTimeStepLayer)l).getUnderlying().activate(in, false, LayerWorkspaceMgr.noWorkspaces()); - INDArray expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5)); + INDArray expOut; + if (rnnDataFormat == RNNFormat.NCW){ + expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5)); + } + else{ + expOut = outUnderlying.get(NDArrayIndex.all(), NDArrayIndex.point(5), NDArrayIndex.all()); + } + //Forward pass: @@ -76,9 +103,17 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { graph.setLayerMaskArrays(new INDArray[]{inMask}, null); expOut = Nd4j.zeros(3, 6); - expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2))); - expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3))); - expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), 
NDArrayIndex.point(4))); + if (rnnDataFormat == RNNFormat.NCW){ + expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(2))); + expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.point(3))); + expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.point(4))); + } + else{ + expOut.putRow(0, outUnderlying.get(NDArrayIndex.point(0), NDArrayIndex.point(2), NDArrayIndex.all())); + expOut.putRow(1, outUnderlying.get(NDArrayIndex.point(1), NDArrayIndex.point(3), NDArrayIndex.all())); + expOut.putRow(2, outUnderlying.get(NDArrayIndex.point(2), NDArrayIndex.point(4), NDArrayIndex.all())); + } + outFwd = l.activate(in, false, LayerWorkspaceMgr.noWorkspaces()); assertEquals(expOut, outFwd); @@ -97,9 +132,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { .seed(1234) .graphBuilder() .addInputs("in") - .setInputTypes(InputType.recurrent(1)) + .setInputTypes(InputType.recurrent(1, rnnDataFormat)) .addLayer("RNN", new LastTimeStep(new LSTM.Builder() - .nOut(10) + .nOut(10).dataFormat(rnnDataFormat) .build()), "in") .addLayer("dense", new DenseLayer.Builder() .nOut(10) @@ -120,7 +155,9 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { INDArray fm2 = Nd4j.zeros(1,24); INDArray fm3 = Nd4j.zeros(1,24); fm3.get(NDArrayIndex.point(0), NDArrayIndex.interval(0,5)).assign(1); - + if (rnnDataFormat == RNNFormat.NWC){ + f = f.permute(0, 2, 1); + } INDArray[] out1 = cg.output(false, new INDArray[]{f}, new INDArray[]{fm1}); try { cg.output(false, new INDArray[]{f}, new INDArray[]{fm2}); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index 11e45c51d..83b312308 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -20,6 +20,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.dropout.TestDropout; import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.LSTM; @@ -31,6 +32,8 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,13 +44,24 @@ import org.nd4j.linalg.primitives.Pair; import java.util.Arrays; import java.util.List; +import java.util.Random; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; +@RunWith(Parameterized.class) public class TestRnnLayers extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + + public TestRnnLayers(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void 
testTimeStepIs3Dimensional() { @@ -58,8 +72,8 @@ public class TestRnnLayers extends BaseDL4JTest { .updater(new NoOp()) .weightInit(WeightInit.XAVIER) .list() - .layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).build()) - .layer(new LSTM.Builder().nIn(3).nOut(5).build()) + .layer(new SimpleRnn.Builder().nIn(nIn).nOut(3).dataFormat(rnnDataFormat).build()) + .layer(new LSTM.Builder().nIn(3).nOut(5).dataFormat(rnnDataFormat).build()) .layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX).build()) .build(); @@ -70,9 +84,9 @@ public class TestRnnLayers extends BaseDL4JTest { org.deeplearning4j.nn.layers.recurrent.SimpleRnn simpleRnn = (org.deeplearning4j.nn.layers.recurrent.SimpleRnn) net.getLayer(0); - INDArray rnnInput3d = Nd4j.create(10, 12, 1); + INDArray rnnInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10,12, 1):Nd4j.create(10, 1, 12); INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(simpleOut.shape(), new long[] {10, 3, 1})); + assertTrue(Arrays.equals(simpleOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 3, 1}:new long[]{10, 1, 3})); INDArray rnnInput2d = Nd4j.create(10, 12); try { @@ -84,9 +98,9 @@ public class TestRnnLayers extends BaseDL4JTest { org.deeplearning4j.nn.layers.recurrent.LSTM lstm = (org.deeplearning4j.nn.layers.recurrent.LSTM) net.getLayer(1); - INDArray lstmInput3d = Nd4j.create(10, 3, 1); + INDArray lstmInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10, 3, 1):Nd4j.create(10, 1, 3); INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(lstmOut.shape(), new long[] {10, 5, 1})); + assertTrue(Arrays.equals(lstmOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 5, 1}:new long[]{10, 1, 5})); INDArray lstmInput2d = Nd4j.create(10, 3); try { @@ -112,19 +126,19 @@ public class TestRnnLayers extends BaseDL4JTest { TestDropout.CustomDropout cd = new TestDropout.CustomDropout(); switch (s){ case "graves": - layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build(); + layer = new GravesLSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD = new GravesLSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD2 = new GravesLSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); break; case "lstm": - layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build(); + layer = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD2 = new org.deeplearning4j.nn.conf.layers.LSTM.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); 
break; case "simple": - layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).build(); - layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).build(); + layer = new SimpleRnn.Builder().activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD = new SimpleRnn.Builder().dropOut(0.5).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); + layerD2 = new SimpleRnn.Builder().dropOut(cd).activation(Activation.TANH).nIn(10).nOut(10).dataFormat(rnnDataFormat).build(); break; default: throw new RuntimeException(s); @@ -134,21 +148,21 @@ public class TestRnnLayers extends BaseDL4JTest { .seed(12345) .list() .layer(layer) - .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) + .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder() .seed(12345) .list() .layer(layerD) - .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) + .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder() .seed(12345) .list() .layer(layerD2) - .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).build()) + .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -178,7 +192,6 @@ public class TestRnnLayers extends BaseDL4JTest { assertNotEquals(s, out2, out2D); INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345); - net.fit(f.dup(), l); netD.fit(f.dup(), l); assertNotEquals(s, net.params(), netD.params()); @@ -205,14 +218,14 @@ public class TestRnnLayers extends BaseDL4JTest { NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder() .list() - .layer(new SimpleRnn.Builder().nIn(5).nOut(5).build()); + .layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build()); switch (i){ case 0: - lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).build()); + lb.layer(new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(5).dataFormat(rnnDataFormat).build()); break; case 1: - lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()); + lb.layer(new RnnLossLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).dataFormat(rnnDataFormat).build()); break; default: throw new RuntimeException(); @@ -223,14 +236,14 @@ public class TestRnnLayers extends BaseDL4JTest { net.init(); INDArray in = Nd4j.rand(DataType.FLOAT, 3, 5, 5); - INDArray l = TestUtils.randomOneHotTimeSeries(3, 5, 10); - + INDArray l = TestUtils.randomOneHotTimeSeries(rnnDataFormat, 3, 5, 10, new 
Random(12345)); try{ net.fit(in,l); } catch (Throwable t){ String msg = t.getMessage(); if(msg == null) t.printStackTrace(); + System.out.println(i); assertTrue(msg, msg != null && msg.contains("sequence length") && msg.contains("input") && msg.contains("label")); } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java index bf8b964b1..639d3fafd 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java @@ -20,10 +20,13 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -36,8 +39,18 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.all; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; import static org.nd4j.linalg.indexing.NDArrayIndex.point; +@RunWith(Parameterized.class) public class TestSimpleRnn extends BaseDL4JTest { + private RNNFormat rnnDataFormat; + + public TestSimpleRnn(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + } + @Parameterized.Parameters + public static Object[] params(){ + return RNNFormat.values(); + } @Test public void testSimpleRnn(){ Nd4j.getRandom().setSeed(12345); @@ -46,15 +59,21 @@ public class TestSimpleRnn extends BaseDL4JTest { int nIn = 5; int layerSize = 6; int tsLength = 7; - INDArray in = Nd4j.rand(DataType.FLOAT, new int[]{m, nIn, tsLength}); -// in.get(all(), all(), interval(1,tsLength)).assign(0); + INDArray in; + if (rnnDataFormat == RNNFormat.NCW){ + in = Nd4j.rand(DataType.FLOAT, m, nIn, tsLength); + } + else{ + in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn); + } + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .updater(new NoOp()) .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) .list() - .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).build()) + .layer(new SimpleRnn.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -68,7 +87,13 @@ public class TestSimpleRnn extends BaseDL4JTest { INDArray outLast = null; for( int i=0; i backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel, int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn, AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, - ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + + //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working + // correctly on NHWC data, even after updating all descriptors, tensor format, etc. 
+ //Therefore: all computation here is done in NCHW format only + //As of a future (next?) release we'll likely switch to C++ for cuDNN support + boolean origNHWC = false; + if(format == CNN2DFormat.NHWC){ + input = input.permute(0,3,1,2); //NHWC to NCHW + delta = delta.permute(0,3,1,2); + origNHWC = true; + } + + int TENSOR_FORMAT = CUDNN_TENSOR_NCHW; + int code; val miniBatch = input.size(0); @@ -147,7 +163,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti val kH = weights.size(2); val kW = weights.size(3); - CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null); + CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above input = args.getInput(); val inH = input.size(2); val inW = input.size(3); @@ -176,7 +192,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]); checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0], - dilation[1], CUDNN_CROSS_CORRELATION, dataType); + dilation[1], CUDNN_CROSS_CORRELATION, dataType); checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW); checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); @@ -238,16 +254,16 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti } } else { code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, - cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, - mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE - : CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, - 0, algo1); + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, + mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE + : CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, + 0, algo1); checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc, - cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, - mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE - : CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, - 0, algo2); + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, + mode == AlgoMode.NO_WORKSPACE ? 
CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE + : CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, + 0, algo2); checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); } @@ -263,7 +279,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti Allocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView, - biasGradView, delta, epsNext); + biasGradView, delta, epsNext); Pointer srcData = allocator.getPointer(input, context); Pointer filterData = allocator.getPointer(weights, context); Pointer filterGradData = allocator.getPointer(weightGradView, context); @@ -279,14 +295,14 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc, - cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0], - sizeInBytes); + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0], + sizeInBytes); checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); long sizeInBytes1 = sizeInBytes.get(0); code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc, - cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0], - sizeInBytes); + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0], + sizeInBytes); checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); @@ -313,21 +329,21 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta, - cudnnContext.biasTensorDesc, biasGradData); + cudnnContext.biasTensorDesc, biasGradData); checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData, - cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace, - workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData); + cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace, + workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData); checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData, - 
cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace, - workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); + cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace, + workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView, - delta, epsNext); + delta, epsNext); Gradient retGradient = new DefaultGradient(); retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView); @@ -344,12 +360,30 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0))); } + if(origNHWC){ + epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC + } + return new Pair<>(retGradient, epsNext); } @Override public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad, - AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, + LayerWorkspaceMgr workspaceMgr) { + + //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working + // correctly on NHWC data, even after updating all descriptors, tensor format, etc. + //Therefore: all computation here is done in NCHW format only + //As of a future (next?) release we'll likely switch to C++ for cuDNN support + boolean origNHWC = false; + if(format == CNN2DFormat.NHWC){ + input = input.permute(0,3,1,2); //NHWC to NCHW + origNHWC = true; + } + + int TENSOR_FORMAT = CUDNN_TENSOR_NCHW; + int code; val miniBatch = input.size(0); @@ -358,7 +392,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti val kH = weights.size(2); val kW = weights.size(3); - CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null); + CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above input = args.getInput(); val inH = input.size(2); val inW = input.size(3); @@ -378,7 +412,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0], - dilation[1], CUDNN_CROSS_CORRELATION, dataType); + dilation[1], CUDNN_CROSS_CORRELATION, dataType); checkCudnn(true, "cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); @@ -460,8 +494,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc, - cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0], - 
sizeInBytes); + cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0], + sizeInBytes); checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); @@ -482,8 +516,8 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace); } code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData, - cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace, - workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); + cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace, + workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); @@ -491,7 +525,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha, - cudnnContext.dstTensorDesc, dstData); + cudnnContext.dstTensorDesc, dstData); checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); allocator.registerAction(context, z, input, weights, bias); @@ -499,6 +533,10 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti if (CudaEnvironment.getInstance().getConfiguration().isDebug()) context.syncOldStream(); + if(origNHWC){ + z = z.permute(0,2,3,1); //NCHW to NHWC + } + return z; } @@ -552,29 +590,29 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti break; case "sigmoid": checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID, - CUDNN_PROPAGATE_NAN, 0)); + CUDNN_PROPAGATE_NAN, 0)); checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, - cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); break; case "relu": checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU, - CUDNN_PROPAGATE_NAN, 0)); + CUDNN_PROPAGATE_NAN, 0)); checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, - cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); break; case "tanh": checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH, - CUDNN_PROPAGATE_NAN, 0)); + CUDNN_PROPAGATE_NAN, 0)); checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, - cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); break; case "softmax": checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha, - cudnnContext.dstTensorDesc, dstData, beta, 
cudnnContext.dstTensorDesc, dstData)); break; case "logsoftmax": checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha, - cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); break; default: activation = null; @@ -593,7 +631,7 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti * @return */ public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation, - ConvolutionMode convolutionMode, PoolingType poolingType){ + ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){ INDArray origInput = input; //Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have issues if strides @@ -602,16 +640,19 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti input = input.dup('c'); } + boolean nchw = format == CNN2DFormat.NCHW; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; - val inH = input.size(2); - val inW = input.size(3); + val inH = input.size(hIdx); + val inW = input.size(wIdx); boolean manualPadBottom = false; boolean manualPadRight = false; int[] outSize; if (convolutionMode == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation); int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation); if(!Arrays.equals(padding, padBottomRight)){ @@ -626,9 +667,17 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti manualPadRight = (padding[1] != padBottomRight[1]); //NCHW format - val newShape = new long[]{input.size(0), input.size(1), - input.size(2) + (manualPadBottom ? 1 : 0), - input.size(3) + (manualPadRight ? 1 : 0)}; + long[] newShape; + if(nchw){ + newShape = new long[]{input.size(0), input.size(1), + input.size(2) + (manualPadBottom ? 1 : 0), + input.size(3) + (manualPadRight ? 1 : 0)}; + } else { + newShape = new long[]{input.size(0), + input.size(1) + (manualPadBottom ? 1 : 0), + input.size(2) + (manualPadRight ?
1 : 0), + input.size(3)}; + } INDArray newInput; if(poolingType == null || poolingType != PoolingType.MAX){ newInput = Nd4j.create(input.dataType(), newShape); @@ -638,15 +687,22 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti // if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType()); } - newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)), - interval(0, input.size(3))}, input); + + if(nchw){ + newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)), + interval(0, input.size(3))}, input); + } else { + newInput.put(new INDArrayIndex[]{all(), interval(0,input.size(1)), + interval(0, input.size(2)), all()}, input); + } + input = newInput; //Now: we've manually applied the "extra" bottom/right padding only - if required. Consequently, we // now have the same amount of padding required for top/bottom, and left/right - which we'll let // CuDNN handle } } else { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation } return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize); @@ -670,4 +726,4 @@ public class CudnnConvolutionHelper extends BaseCudnnHelper implements Convoluti return Collections.emptyMap(); } -} +} \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java index 7fb9bf51e..84ed6ef63 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/CudnnSubsamplingHelper.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.subsampling; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -114,23 +115,29 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli @Override public Pair<Gradient,INDArray> backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, - int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, + int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(dilation[0] != 1 || dilation[1] != 1){ //CuDNN doesn't support dilated subsampling return null; } + boolean nchw = format == CNN2DFormat.NCHW; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + //We require the output as one of the arguments for backprop here //TODO we could add cache mode support here somehow... 
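/* Illustrative aside (editorial sketch, not part of the patch): every cuDNN helper in this
   diff follows the same convention - derive chIdx/hIdx/wIdx from the CNN2DFormat, and for
   cases cuDNN cannot handle directly, permute NHWC activations to NCHW before the cuDNN
   call and permute the result back. A minimal, self-contained demonstration, assuming only
   core ND4J APIs (Nd4j.rand, INDArray.permute, INDArray.size) already used in this patch;
   the class name below is hypothetical: */
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LayoutRemapSketch {
    public static void main(String[] args) {
        INDArray nchw = Nd4j.rand(new int[]{2, 3, 12, 12}); // [minibatch, channels, height, width]
        INDArray nhwc = nchw.permute(0, 2, 3, 1);           // NCHW -> NHWC, shape [2, 12, 12, 3]

        boolean isNchw = false;                             // as if format == CNN2DFormat.NHWC
        int chIdx = isNchw ? 1 : 3;
        int hIdx = isNchw ? 2 : 1;
        int wIdx = isNchw ? 3 : 2;

        // The remapped indices recover the same logical dimensions in either layout:
        System.out.println(nhwc.size(chIdx)); // 3  (channels)
        System.out.println(nhwc.size(hIdx));  // 12 (height)
        System.out.println(nhwc.size(wIdx));  // 12 (width)

        // Round-trip applied to NHWC inputs: to NCHW for the cuDNN call, back afterwards.
        INDArray forCudnn = nhwc.permute(0, 3, 1, 2);       // NHWC -> NCHW
        INDArray restored = forCudnn.permute(0, 2, 3, 1);   // NCHW -> NHWC
        System.out.println(restored.equalShapes(nhwc));     // true
    }
}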
- INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, workspaceMgr); + INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr); val miniBatch = input.size(0); - val depth = input.size(1); + val depth = input.size(chIdx); - CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType); + CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format); input = args.getInput(); - val inH = input.size(2); - val inW = input.size(3); + val inH = input.size(hIdx); + val inW = input.size(wIdx); val srcStride = input.stride(); int[] outSize = args.getOutSize(); int outH = outSize[0]; @@ -160,23 +167,26 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli epsilon = epsilon.dup('c'); } + input = input.dup(); + val deltaStride = epsilon.stride(); if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, - (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); + (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx])); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW, - (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3])); + (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx])); checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0], kernel[1], pad[0], pad[1], strides[0], strides[1])); - INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {(int) miniBatch, (int) depth, (int) inH, (int) inW}, 'c'); + long[] outEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth}; + INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c'); val dstStride = outEpsilon.stride(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, - (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3])); + (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx])); Allocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon); @@ -198,9 +208,16 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli //Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon // we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input. if(args.isManualPadBottom() || args.isManualPadRight()) { - outEpsilon = outEpsilon.get(all(), all(), - interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)), - interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0))); + if(nchw){ + outEpsilon = outEpsilon.get(all(), all(), + interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 
1 : 0)), + interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0))); + } else { + outEpsilon = outEpsilon.get(all(), + interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)), + interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)), + all()); + } } return new Pair<>(retGradient, outEpsilon); @@ -209,19 +226,24 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli @Override public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, - PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(dilation[0] != 1 || dilation[1] != 1){ //CuDNN doesn't support dilated subsampling return null; } - val miniBatch = input.size(0); - val inDepth = input.size(1); + boolean nchw = format == CNN2DFormat.NCHW; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; - CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType); + val miniBatch = input.size(0); + val inDepth = input.size(nchw ? 1 : 3); + + CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format); input = args.getInput(); - val inH = input.size(2); - val inW = input.size(3); + val inH = input.size(nchw ? 2 : 1); + val inW = input.size(nchw ? 3 : 2); val srcStride = input.stride(); val outSize = args.getOutSize(); int outH = outSize[0]; @@ -246,13 +268,14 @@ public class CudnnSubsamplingHelper extends BaseCudnnHelper implements Subsampli checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0], kernel[1], pad[0], pad[1], strides[0], strides[1])); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW, - (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); + (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx])); - INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {(int) miniBatch, (int) inDepth, outH, outW}, 'c'); + long[] outShape = nchw ? 
new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth}; + INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c'); val dstStride = reduced.stride(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW, - (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3])); + (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx])); Allocator allocator = AtomicAllocator.getInstance(); CudaContext context = allocator.getFlowController().prepareAction(input, reduced); diff --git a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java index 6d826c5eb..fd8cd2657 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java +++ b/deeplearning4j/deeplearning4j-cuda/src/main/java/org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.normalization; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseCudnnHelper; @@ -124,12 +125,21 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba @Override public Pair<Gradient,INDArray> backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, - INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr layerWorkspaceMgr) { + INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) { + + boolean nchw = format == CNN2DFormat.NCHW; + this.eps = eps; + + int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 
3 : 2; + val miniBatch = (int) input.size(0); - val depth = (int) input.size(1); - val inH = (int) input.size(2); - val inW = (int) input.size(3); + val depth = (int) input.size(chIdx); + val inH = (int) input.size(hIdx); + val inW = (int) input.size(wIdx); final boolean isHalf = (input.dataType() == DataType.HALF); INDArray gammaOrig = null; @@ -164,16 +174,17 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, - (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3])); + (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx])); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, - (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3])); + (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx])); - INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c'); + long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth}; + INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c'); val dstStride = ArrayUtil.toInts(nextEpsilon.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, - dstStride[0], dstStride[1], dstStride[2], dstStride[3])); - checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(gamma.data().dataType()), (int)shape[0], + dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx])); + checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0], (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1)); Allocator allocator = AtomicAllocator.getInstance(); @@ -215,9 +226,15 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba @Override public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, - INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr) { + INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + boolean nchw = format == CNN2DFormat.NCHW; + int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + this.eps = eps; - final boolean isHalf = (x.dataType() == DataType.HALF); + final boolean isHalf = (x.dataType() == DataType.FLOAT16); INDArray origGamma = gamma; INDArray origBeta = beta; INDArray origMean = mean; @@ -238,21 +255,22 @@ public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements Ba decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). 
-> 0 = "in-place modification of running mean disabled" val miniBatch = (int) x.size(0); - val inDepth = (int) x.size(1); - val inH = (int) x.size(2); - val inW = (int) x.size(3); + val inDepth = (int) x.size(chIdx); + val inH = (int) x.size(hIdx); + val inW = (int) x.size(wIdx); val srcStride = ArrayUtil.toInts(x.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW, - srcStride[0], srcStride[1], srcStride[2], srcStride[3])); + srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx])); - INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c'); + long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth}; + INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c'); val dstStride = ArrayUtil.toInts(activations.stride()); checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, - dstStride[0], dstStride[1], dstStride[2], dstStride[3])); + dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx])); - checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, TENSOR_FORMAT, toCudnnDataType(mean.data().dataType()), (int)shape[0], + checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0], (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1)); Allocator allocator = AtomicAllocator.getInstance(); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestUtils.java index 8d6933846..d54693f73 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/TestUtils.java @@ -16,74 +16,131 @@ package org.deeplearning4j; +import org.apache.commons.compress.utils.IOUtils; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer; +import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer; +import org.deeplearning4j.nn.layers.normalization.BatchNormalization; +import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization; +import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.ModelSerializer; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.regularization.L1Regularization; +import org.nd4j.linalg.learning.regularization.L2Regularization; +import org.nd4j.linalg.learning.regularization.Regularization; +import org.nd4j.linalg.learning.regularization.WeightDecay; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import 
java.io.IOException; +import java.io.*; +import java.lang.reflect.Field; +import java.util.List; import java.util.Random; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; public class TestUtils { public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){ + MultiLayerNetwork restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); + restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); assertEquals(net.params(), restored.params()); - - return restored; } catch (IOException e){ //Should never happen throw new RuntimeException(e); } + + //Also check the MultiLayerConfiguration is serializable (required by Spark etc) + MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); + serializeDeserializeJava(conf); + + return restored; } public static ComputationGraph testModelSerialization(ComputationGraph net){ - + ComputationGraph restored; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(net, baos, true); byte[] bytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true); + restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); - - return restored; } catch (IOException e){ //Should never happen throw new RuntimeException(e); } + + //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) + ComputationGraphConfiguration conf = net.getConfiguration(); + serializeDeserializeJava(conf); + + return restored; } - public static INDArray randomOneHot(int examples, int nOut){ + private static <T> T serializeDeserializeJava(T object){ + byte[] bytes; + try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){ + oos.writeObject(object); + oos.close(); + bytes = baos.toByteArray(); + } catch (IOException e){ + //Should never happen + throw new RuntimeException(e); + } + + T out; + try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){ + out = (T)ois.readObject(); + } catch (IOException | ClassNotFoundException e){ + throw new RuntimeException(e); + } + + assertEquals(object, out); + return out; + } + + public static INDArray randomOneHot(long examples, long nOut){ return randomOneHot(examples, nOut, new Random(12345)); } - public static INDArray randomOneHot(int examples, int nOut, long rngSeed){ + public static INDArray randomOneHot(DataType dataType, long examples, long nOut){ + return randomOneHot(dataType, examples, nOut, new Random(12345)); + } + + public static INDArray randomOneHot(long examples, long nOut, long rngSeed){ return randomOneHot(examples, nOut, new Random(rngSeed)); } - public static INDArray randomOneHot(int examples, int nOut, Random rng){ - INDArray arr = Nd4j.create(examples, nOut); + public static INDArray randomOneHot(long examples, long nOut, Random rng) { + return randomOneHot(Nd4j.defaultFloatingPointType(), examples,nOut, rng); + } + + public static INDArray randomOneHot(DataType dataType, 
long examples, long nOut, Random rng){ + INDArray arr = Nd4j.create(dataType, examples, nOut); + for( int i=0; i<examples; i++ ){ + arr.putScalar(i, rng.nextInt((int) nOut), 1.0); + } + return arr; + } + + public static L1Regularization getL1Reg(BaseLayer baseLayer){ + return getL1Reg(baseLayer.getRegularization()); + } + + public static L1Regularization getL1Reg(List<Regularization> l){ + for(Regularization r : l){ + if(r instanceof L1Regularization){ + return (L1Regularization) r; + } + } + return null; + } + + public static L2Regularization getL2Reg(BaseLayer baseLayer){ + return getL2Reg(baseLayer.getRegularization()); + } + + public static L2Regularization getL2Reg(List<Regularization> l){ + for(Regularization r : l){ + if(r instanceof L2Regularization){ + return (L2Regularization) r; + } + } + return null; + } + + public static WeightDecay getWeightDecayReg(BaseLayer bl){ + return getWeightDecayReg(bl.getRegularization()); + } + + public static WeightDecay getWeightDecayReg(List<Regularization> l){ + for(Regularization r : l){ + if(r instanceof WeightDecay){ + return (WeightDecay) r; + } + } + return null; + } + + public static double getL1(BaseLayer layer) { + List<Regularization> l = layer.getRegularization(); + return getL1(l); + } + + public static double getL1(List<Regularization> l){ + L1Regularization l1Reg = null; + for(Regularization reg : l){ + if(reg instanceof L1Regularization) + l1Reg = (L1Regularization) reg; + } + assertNotNull(l1Reg); + return l1Reg.getL1().valueAt(0,0); + } + + public static double getL2(BaseLayer layer) { + List<Regularization> l = layer.getRegularization(); + return getL2(l); + } + + public static double getL2(List<Regularization> l){ + L2Regularization l2Reg = null; + for(Regularization reg : l){ + if(reg instanceof L2Regularization) + l2Reg = (L2Regularization) reg; + } + assertNotNull(l2Reg); + return l2Reg.getL2().valueAt(0,0); + } + + public static double getL1(AbstractSameDiffLayer layer){ + return getL1(layer.getRegularization()); + } + + public static double getL2(AbstractSameDiffLayer layer){ + return getL2(layer.getRegularization()); + } + + public static double getWeightDecay(BaseLayer layer) { + return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0); + } + + public static void removeHelper(Layer layer) throws Exception { + removeHelpers(new Layer[]{layer}); + } + + public static void removeHelpers(Layer[] layers) throws Exception { + for(Layer l : layers){ + + if(l instanceof ConvolutionLayer){ + Field f1 = ConvolutionLayer.class.getDeclaredField("helper"); + f1.setAccessible(true); + f1.set(l, null); + } else if(l instanceof SubsamplingLayer){ + Field f2 = SubsamplingLayer.class.getDeclaredField("helper"); + f2.setAccessible(true); + f2.set(l, null); + } else if(l instanceof BatchNormalization) { + Field f3 = BatchNormalization.class.getDeclaredField("helper"); + f3.setAccessible(true); + f3.set(l, null); + } else if(l instanceof LSTM){ + Field f4 = LSTM.class.getDeclaredField("helper"); + f4.setAccessible(true); + f4.set(l, null); + } else if(l instanceof LocalResponseNormalization){ + Field f5 = LocalResponseNormalization.class.getDeclaredField("helper"); + f5.setAccessible(true); + f5.set(l, null); + } + + + if(l.getHelper() != null){ + throw new IllegalStateException("Did not remove helper for layer: " + l.getClass().getSimpleName()); + } + } + } + + public static void assertHelperPresent(Layer layer){ + + } + + public static void assertHelpersPresent(Layer[] layers) throws Exception { + for(Layer l : layers){ + //Don't use instanceof here - there are sub conv subclasses + if(l.getClass() == ConvolutionLayer.class || l instanceof SubsamplingLayer || l instanceof BatchNormalization || l instanceof LSTM){ + Preconditions.checkNotNull(l.getHelper(), l.conf().getLayer().getLayerName()); + } + } + } + + public static void assertHelpersAbsent(Layer[] layers) throws Exception { + 
for(Layer l : layers){ + Preconditions.checkState(l.getHelper() == null, l.conf().getLayer().getLayerName()); + } + } } diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java new file mode 100644 index 000000000..8210903ef --- /dev/null +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/ConvDataFormatTests.java @@ -0,0 +1,1007 @@ +/* ****************************************************************************** + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.convolution; + +import lombok.*; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.CuDNNTestUtils; +import org.deeplearning4j.TestUtils; +import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; +import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; +import org.deeplearning4j.nn.conf.preprocessor.ComposableInputPreProcessor; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +@RunWith(Parameterized.class) +public class ConvDataFormatTests extends BaseDL4JTest { + + private final DataType dataType; + + public ConvDataFormatTests(DataType dataType){ + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{0}") + public static Object[] params(){ + return new DataType[]{DataType.FLOAT, DataType.DOUBLE}; + } + + @Test + public void testConv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getConv2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getConv2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getConv2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getConv2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSubsampling2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSubsampling2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getSubsampling2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getSubsampling2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getSubsampling2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testDepthwiseConv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getDepthwiseConv2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getDepthwiseConv2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getDepthwiseConv2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getDepthwiseConv2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSeparableConv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSeparableConv2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getSeparableConv2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getSeparableConv2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getSeparableConv2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testDeconv2d() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getDeconv2DNet2dNet(CNN2DFormat.NCHW, true, cm)) + .net2(getDeconv2DNet2dNet(CNN2DFormat.NCHW, false, cm)) + .net3(getDeconv2DNet2dNet(CNN2DFormat.NHWC, true, cm)) + .net4(getDeconv2DNet2dNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testLRN() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getLrnLayer(CNN2DFormat.NCHW, true, cm)) + .net2(getLrnLayer(CNN2DFormat.NCHW, false, cm)) + .net3(getLrnLayer(CNN2DFormat.NHWC, true, cm)) + .net4(getLrnLayer(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testZeroPaddingLayer(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getZeroPaddingNet(CNN2DFormat.NCHW, true)) + .net2(getZeroPaddingNet(CNN2DFormat.NCHW, false)) + .net3(getZeroPaddingNet(CNN2DFormat.NHWC, true)) + .net4(getZeroPaddingNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testCropping2DLayer(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getCropping2dNet(CNN2DFormat.NCHW, true)) + .net2(getCropping2dNet(CNN2DFormat.NCHW, false)) + .net3(getCropping2dNet(CNN2DFormat.NHWC, true)) + .net4(getCropping2dNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testUpsampling2d(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getUpsamplingNet(CNN2DFormat.NCHW, true)) + .net2(getUpsamplingNet(CNN2DFormat.NCHW, false)) + .net3(getUpsamplingNet(CNN2DFormat.NHWC, true)) + .net4(getUpsamplingNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testBatchNormNet(){ + try { + for(boolean useLogStd : new boolean[]{true, false}) { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = (helpers ? "With helpers" : "No helpers") + " - " + (useLogStd ? 
"logstd" : "std"); + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, true)) + .net2(getBatchNormNet(useLogStd, CNN2DFormat.NCHW, false)) + .net3(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, true)) + .net4(getBatchNormNet(useLogStd, CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testCnnLossLayer() { + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labelsNHWC = TestUtils.randomOneHot(this.dataType,2*6*6, 3); + labelsNHWC = labelsNHWC.reshape(2,6,6,3); + INDArray labelsNCHW = labelsNHWC.permute(0,3,1,2).dup(); + + + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getCnnLossNet(CNN2DFormat.NCHW, true, ConvolutionMode.Same)) + .net2(getCnnLossNet(CNN2DFormat.NCHW, false, ConvolutionMode.Same)) + .net3(getCnnLossNet(CNN2DFormat.NHWC, true, ConvolutionMode.Same)) + .net4(getCnnLossNet(CNN2DFormat.NHWC, false, ConvolutionMode.Same)) + .inNCHW(inNCHW) + .labelsNCHW(labelsNCHW) + .labelsNHWC(labelsNHWC) + .testLayerIdx(1) + .nhwcOutput(true) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSpaceToDepthNet(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSpaceToDepthNet(CNN2DFormat.NCHW, true)) + .net2(getSpaceToDepthNet(CNN2DFormat.NCHW, false)) + .net3(getSpaceToDepthNet(CNN2DFormat.NHWC, true)) + .net4(getSpaceToDepthNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testSpaceToBatchNet(){ + try { + for (boolean helpers : new boolean[]{false, true}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers" : "No helpers"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 16, 16); + INDArray labels = TestUtils.randomOneHot(8, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getSpaceToBatchNet(CNN2DFormat.NCHW, true)) + .net2(getSpaceToBatchNet(CNN2DFormat.NCHW, false)) + .net3(getSpaceToBatchNet(CNN2DFormat.NHWC, true)) + .net4(getSpaceToBatchNet(CNN2DFormat.NHWC, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testLocallyConnected() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (ConvolutionMode cm : new ConvolutionMode[]{ConvolutionMode.Truncate, ConvolutionMode.Same}) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? "With helpers (" + cm + ")" : "No helpers (" + cm + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getLocallyConnectedNet(CNN2DFormat.NCHW, true, cm)) + .net2(getLocallyConnectedNet(CNN2DFormat.NCHW, false, cm)) + .net3(getLocallyConnectedNet(CNN2DFormat.NHWC, true, cm)) + .net4(getLocallyConnectedNet(CNN2DFormat.NHWC, false, cm)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .helpers(helpers) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + @Test + public void testGlobalPooling() { + try { + for (boolean helpers : new boolean[]{false, true}) { + for (PoolingType pt : PoolingType.values()) { + Nd4j.getRandom().setSeed(12345); + Nd4j.getEnvironment().allowHelpers(helpers); + String msg = helpers ? 
"With helpers (" + pt + ")" : "No helpers (" + pt + ")"; + System.out.println(" --- " + msg + " ---"); + + INDArray inNCHW = Nd4j.rand(this.dataType, 2, 3, 12, 12); + INDArray labels = TestUtils.randomOneHot(2, 10); + + TestCase tc = TestCase.builder() + .msg(msg) + .net1(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, true)) + .net2(getGlobalPoolingNet(CNN2DFormat.NCHW, pt, false)) + .net3(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, true)) + .net4(getGlobalPoolingNet(CNN2DFormat.NHWC, pt, false)) + .inNCHW(inNCHW) + .labelsNCHW(labels) + .labelsNHWC(labels) + .testLayerIdx(1) + .build(); + + testHelper(tc); + } + } + } finally { + Nd4j.getEnvironment().allowHelpers(true); + } + } + + private MultiLayerNetwork getConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getSubsampling2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new SubsamplingLayer.Builder() + .kernelSize(2, 2) + .stride(1, 1) + .dataFormat(format) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new SubsamplingLayer.Builder() + .kernelSize(2, 2) + .stride(1, 1) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getSeparableConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new SeparableConvolution2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new SeparableConvolution2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getDepthwiseConv2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new DepthwiseConvolution2D.Builder() + .depthMultiplier(2) + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new DepthwiseConvolution2D.Builder() + .depthMultiplier(2) + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getLrnLayer(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new LocalResponseNormalization.Builder() + .dataFormat(format) + .helperAllowFallback(false) + .build(), format, cm, null); + } else { + return getNetWithLayer(new LocalResponseNormalization.Builder() + .helperAllowFallback(false) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getZeroPaddingNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2) + 
.dataFormat(format).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new ZeroPaddingLayer.Builder(2,2).build(), + format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getCropping2dNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new Cropping2D.Builder(2,2) + .dataFormat(format).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new Cropping2D.Builder(2,2) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getUpsamplingNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new Upsampling2D.Builder(2) + .dataFormat(format).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new Upsampling2D.Builder(2) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getDeconv2DNet2dNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) + .activation(Activation.TANH) + .kernelSize(2,2) + .stride(2,2) + .build(), format, cm, null); + } else { + return getNetWithLayer(new Deconvolution2D.Builder().nOut(2) + .activation(Activation.TANH) + .kernelSize(2,2) + .stride(2,2) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getBatchNormNet(boolean logStdev, CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new BatchNormalization.Builder() + .useLogStd(logStdev) + .dataFormat(format) + .helperAllowFallback(false) + .nOut(3).build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new BatchNormalization.Builder() + .useLogStd(logStdev) + .helperAllowFallback(false) + .nOut(3).build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new SpaceToDepthLayer.Builder() + .blocks(2) + .dataFormat(format) + .build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new SpaceToDepthLayer.Builder() + .blocks(2) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new SpaceToBatchLayer.Builder() + .blocks(2, 2) + .dataFormat(format) + .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); + } else { + return getNetWithLayer(new SpaceToBatchLayer.Builder() + .blocks(2, 2) + .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); + } + } + + private MultiLayerNetwork getLocallyConnectedNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm) { + if (setOnLayerAlso) { + return getNetWithLayer(new LocallyConnected2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .build(), format, cm, null); + } else { + return getNetWithLayer(new LocallyConnected2D.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .nOut(3) + .build(), format, cm, null); + } + } + + private MultiLayerNetwork getGlobalPoolingNet(CNN2DFormat format, PoolingType pt, boolean setOnLayerAlso) { + if (setOnLayerAlso) { + return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) + .poolingDimensions(format == CNN2DFormat.NCHW ? 
new int[]{2,3} : new int[]{1,2}) + .build(), format, ConvolutionMode.Same, null); + } else { + return getNetWithLayer(new GlobalPoolingLayer.Builder(pt) + .build(), format, ConvolutionMode.Same, null); + } + } + + private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){ + NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + .seed(12345) + .convolutionMode(cm) + .list() + .layer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build()); + if(setOnLayerAlso){ + builder.layer(new CnnLossLayer.Builder().format(format).activation(Activation.SOFTMAX).build()); + } else { + builder.layer(new CnnLossLayer.Builder().activation(Activation.SOFTMAX).build()); + } + + builder.setInputType(InputType.convolutional(12, 12, 3, format)); + + MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); + net.init(); + return net; + } + + private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) { + NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + .dataType(this.dataType) + .seed(12345) + .convolutionMode(cm) + .list() + .layer(new ConvolutionLayer.Builder() + .kernelSize(3, 3) + .stride(2, 2) + .activation(Activation.TANH) + .dataFormat(format) + .nOut(3) + .helperAllowFallback(false) + .build()) + .layer(layer) + .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build()) + .setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format)); + + if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){ + //Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened + //DL4J's flattening behaviour matches Keras (hence TF) for import compatibility + builder.inputPreProcessor(2, new ComposableInputPreProcessor(new NHWCToNCHWPreprocessor(), new CnnToFeedForwardPreProcessor())); + } + + MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); + net.init(); + return net; + } + + @AllArgsConstructor + @Data + @NoArgsConstructor + @Builder + private static class TestCase { + private String msg; + private MultiLayerNetwork net1; + private MultiLayerNetwork net2; + private MultiLayerNetwork net3; + private MultiLayerNetwork net4; + private INDArray inNCHW; + private INDArray labelsNCHW; + private INDArray labelsNHWC; + private int testLayerIdx; + private boolean nhwcOutput; + private boolean helpers; + } + + public static void testHelper(TestCase tc) { + + if(!tc.helpers){ + try { + CuDNNTestUtils.removeHelpers(tc.net1.getLayers()); + CuDNNTestUtils.removeHelpers(tc.net2.getLayers()); + CuDNNTestUtils.removeHelpers(tc.net3.getLayers()); + CuDNNTestUtils.removeHelpers(tc.net4.getLayers()); + } catch (Throwable t){ + throw new RuntimeException(t); + } + } + + + tc.net2.params().assign(tc.net1.params()); + tc.net3.params().assign(tc.net1.params()); + tc.net4.params().assign(tc.net1.params()); + + //Test forward pass: + INDArray inNCHW = tc.inNCHW; + INDArray inNHWC = tc.inNCHW.permute(0, 2, 3, 1).dup(); + + INDArray l0_1 = tc.net1.feedForward(inNCHW).get(tc.testLayerIdx + 1); + INDArray l0_2 = tc.net2.feedForward(inNCHW).get(tc.testLayerIdx + 1); + INDArray l0_3 = tc.net3.feedForward(inNHWC).get(tc.testLayerIdx + 1); + INDArray l0_4 = tc.net4.feedForward(inNHWC).get(tc.testLayerIdx + 1); + + assertEquals(tc.msg, l0_1, l0_2); + if(l0_1.rank() 
== 4) { + assertEquals(tc.msg, l0_1, l0_3.permute(0, 3, 1, 2)); + assertEquals(tc.msg, l0_1, l0_4.permute(0, 3, 1, 2)); + } else { + assertEquals(tc.msg, l0_1, l0_3); + assertEquals(tc.msg, l0_1, l0_4); + } + + + INDArray out1 = tc.net1.output(inNCHW); + INDArray out2 = tc.net2.output(inNCHW); + INDArray out3 = tc.net3.output(inNHWC); + INDArray out4 = tc.net4.output(inNHWC); + + assertEquals(tc.msg, out1, out2); + if(!tc.nhwcOutput) { + assertEquals(tc.msg, out1, out3); + assertEquals(tc.msg, out1, out4); + } else { + assertEquals(tc.msg, out1, out3.permute(0,3,1,2)); //NHWC to NCHW + assertEquals(tc.msg, out1, out4.permute(0,3,1,2)); + } + + //Test backprop + Pair<Gradient,INDArray> p1 = tc.net1.calculateGradients(inNCHW, tc.labelsNCHW, null, null); + Pair<Gradient,INDArray> p2 = tc.net2.calculateGradients(inNCHW, tc.labelsNCHW, null, null); + Pair<Gradient,INDArray> p3 = tc.net3.calculateGradients(inNHWC, tc.labelsNHWC, null, null); + Pair<Gradient,INDArray> p4 = tc.net4.calculateGradients(inNHWC, tc.labelsNHWC, null, null); + + //Input gradients + assertEquals(tc.msg, p1.getSecond(), p2.getSecond()); + assertEquals(tc.msg, p1.getSecond(), p3.getSecond().permute(0,3,1,2)); //Input gradients for NHWC input are also in NHWC format + assertEquals(tc.msg, p1.getSecond(), p4.getSecond().permute(0,3,1,2)); + + List<String> diff12 = differentGrads(p1.getFirst(), p2.getFirst()); + List<String> diff13 = differentGrads(p1.getFirst(), p3.getFirst()); + List<String> diff14 = differentGrads(p1.getFirst(), p4.getFirst()); + assertEquals(tc.msg + " " + diff12, 0, diff12.size()); + assertEquals(tc.msg + " " + diff13, 0, diff13.size()); + assertEquals(tc.msg + " " + diff14, 0, diff14.size()); + + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p2.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p3.getFirst().gradientForVariable()); + assertEquals(tc.msg, p1.getFirst().gradientForVariable(), p4.getFirst().gradientForVariable()); + + tc.net1.fit(inNCHW, tc.labelsNCHW); + tc.net2.fit(inNCHW, tc.labelsNCHW); + tc.net3.fit(inNHWC, tc.labelsNHWC); + tc.net4.fit(inNHWC, tc.labelsNHWC); + + assertEquals(tc.msg, tc.net1.params(), tc.net2.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net3.params()); + assertEquals(tc.msg, tc.net1.params(), tc.net4.params()); + + //Test serialization + MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); + MultiLayerNetwork net2a = TestUtils.testModelSerialization(tc.net2); + MultiLayerNetwork net3a = TestUtils.testModelSerialization(tc.net3); + MultiLayerNetwork net4a = TestUtils.testModelSerialization(tc.net4); + + if(!tc.helpers){ + try { + CuDNNTestUtils.removeHelpers(net1a.getLayers()); + CuDNNTestUtils.removeHelpers(net2a.getLayers()); + CuDNNTestUtils.removeHelpers(net3a.getLayers()); + CuDNNTestUtils.removeHelpers(net4a.getLayers()); + } catch (Throwable t){ + throw new RuntimeException(t); + } + } + + out1 = tc.net1.output(inNCHW); + assertEquals(tc.msg, out1, net1a.output(inNCHW)); + assertEquals(tc.msg, out1, net2a.output(inNCHW)); + if(!tc.nhwcOutput) { + assertEquals(tc.msg, out1, net3a.output(inNHWC)); + assertEquals(tc.msg, out1, net4a.output(inNHWC)); + } else { + assertEquals(tc.msg, out1, net3a.output(inNHWC).permute(0,3,1,2)); //NHWC to NCHW + assertEquals(tc.msg, out1, net4a.output(inNHWC).permute(0,3,1,2)); + } + + } + + private static List<String> differentGrads(Gradient g1, Gradient g2){ + List<String> differs = new ArrayList<>(); + Map<String,INDArray> m1 = g1.gradientForVariable(); + Map<String,INDArray> m2 = g2.gradientForVariable(); + for(String s : m1.keySet()){ + INDArray a1 = m1.get(s); + INDArray a2 = m2.get(s); +
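//Record the name of any variable whose gradient array differs between the two networks being compared +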
if(!a1.equals(a2)){ + differs.add(s); + } + } + return differs; + } + + //Converts NHWC to NCHW activations + @EqualsAndHashCode + private static class NHWCToNCHWPreprocessor implements InputPreProcessor { + + @Override + public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.permute(0,3,1,2)); + } + + @Override + public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.permute(0,2,3,1)); + } + + @Override + public InputPreProcessor clone() { + return this; + } + + @Override + public InputType getOutputType(InputType inputType) { + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; + return InputType.convolutional(c.getHeight(), c.getWidth(), c.getChannels(), CNN2DFormat.NCHW); + } + + @Override + public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { + return null; + } + } +} diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java index 67a2958b7..72ab2c043 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java @@ -212,7 +212,7 @@ public class TestConvolution extends BaseDL4JTest { ComputationGraph model = KerasModelImport.importKerasModelAndWeights( fExtracted.getAbsolutePath(), new int[]{inSize, inSize, 3}, false); model = model.convertDataType(DataType.DOUBLE); - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 3, inSize, inSize}); + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, inSize, inSize, 3}); //Keras import model -> NHWC CuDNNTestUtils.assertHelpersPresent(model.getLayers()); Map<String,INDArray> withCudnn = model.feedForward(in, false); diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/EmnistDataFetcher.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/EmnistDataFetcher.java index 674043a8a..071c8c009 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/EmnistDataFetcher.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/EmnistDataFetcher.java @@ -16,6 +16,7 @@ package org.deeplearning4j.datasets.fetchers; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.deeplearning4j.base.EmnistFetcher; @@ -36,6 +37,7 @@ import java.util.Random; * @author Alex Black * */ +@Slf4j public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetcher { protected EmnistFetcher fetcher; @@ -64,7 +66,7 @@ try { man = new MnistManager(images, labels, totalExamples); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); FileUtils.deleteDirectory(new File(EMNIST_ROOT)); new EmnistFetcher(dataSet).downloadAndUntar(); man = new MnistManager(images, labels, totalExamples); diff --git
a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java index 8b8b6a524..7595a0777 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java @@ -687,9 +687,7 @@ public class BarnesHutTsne implements Model { * @throws IOException */ public void saveAsFile(List<String> labels, String path) throws IOException { - BufferedWriter write = null; - try { - write = new BufferedWriter(new FileWriter(new File(path))); + try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) { for (int i = 0; i < Y.rows(); i++) { if (i >= labels.size()) break; @@ -711,17 +709,11 @@ } write.flush(); - write.close(); - } finally { - if (write != null) - write.close(); } } public void saveAsFile(String path) throws IOException { - BufferedWriter write = null; - try { - write = new BufferedWriter(new FileWriter(new File(path))); + try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) { for (int i = 0; i < Y.rows(); i++) { StringBuilder sb = new StringBuilder(); INDArray wordVector = Y.getRow(i); @@ -734,10 +726,6 @@ write.write(sb.toString()); } write.flush(); - write.close(); - } finally { - if (write != null) - write.close(); } } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/pom.xml b/deeplearning4j/deeplearning4j-modelimport/pom.xml index 6d71c394e..3dcdcc720 100644 --- a/deeplearning4j/deeplearning4j-modelimport/pom.xml +++ b/deeplearning4j/deeplearning4j-modelimport/pom.xml @@ -113,6 +113,12 @@ <scope>test</scope> </dependency> + <dependency> + <groupId>org.datavec</groupId> + <artifactId>datavec-python</artifactId> + <version>${datavec.version}</version> + <scope>test</scope> + </dependency> diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java index a5ea8efca..73a9a26ea 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java @@ -60,7 +60,7 @@ public class Hdf5Archive implements Closeable { /* This is necessary for the call to the BytePointer constructor below.
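Loading the org.bytedeco.hdf5 class via JavaCPP's Loader initializes the native HDF5 bindings up front, so the native library is already linked when that constructor runs.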
*/ Loader.load(org.bytedeco.hdf5.global.hdf5.class); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java index 785e480d1..d83d27011 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasInput.java @@ -19,6 +19,8 @@ package org.deeplearning4j.nn.modelimport.keras.layers; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -121,27 +123,27 @@ public class KerasInput extends KerasLayer { InputType myInputType; switch (this.inputShape.length) { case 1: - myInputType = new InputType.InputTypeFeedForward(this.inputShape[0]); + myInputType = new InputType.InputTypeFeedForward(this.inputShape[0], null); break; case 2: if(this.dimOrder != null) { switch (this.dimOrder) { case TENSORFLOW: //NWC == channels_last - myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0]); + myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC); break; case THEANO: //NCW == channels_first - myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1], RNNFormat.NCW); break; case NONE: //Assume RNN in [mb, seqLen, size] format - myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC); break; default: throw new IllegalStateException("Unknown/not supported dimension ordering: " + this.dimOrder); } } else { //Assume RNN in [mb, seqLen, size] format - myInputType = new InputType.InputTypeRecurrent(this.inputShape[0], this.inputShape[1]); + myInputType = new InputType.InputTypeRecurrent(this.inputShape[1], this.inputShape[0], RNNFormat.NWC); } break; @@ -150,17 +152,17 @@ case TENSORFLOW: /* TensorFlow convolutional input: # rows, # cols, # channels */ myInputType = new InputType.InputTypeConvolutional(this.inputShape[0], this.inputShape[1], - this.inputShape[2]); + this.inputShape[2], CNN2DFormat.NHWC); break; case THEANO: /* Theano convolutional input: # channels, # rows, # cols */ myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2], - this.inputShape[0]); + this.inputShape[0], CNN2DFormat.NCHW); break; default: this.dimOrder = DimOrder.THEANO; myInputType = new InputType.InputTypeConvolutional(this.inputShape[1], this.inputShape[2], - this.inputShape[0]); + this.inputShape[0], CNN2DFormat.NCHW); log.warn("Couldn't determine dim ordering / data format from model file.
Older Keras " + "versions may come without specified backend, in which case we assume the model was " + "built with theano." ); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java index ecf64e8c0..5a1f0e8dd 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java @@ -20,6 +20,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; @@ -65,6 +66,9 @@ public class TFOpLayer extends Layer { long[] shape = inputType.getShape(true); TFOpLayerImpl tempLayer = new TFOpLayerImpl(nodeDef, constants, null, null); long[] outputShape = tempLayer.getOutputShape(shape); + if (outputShape.length == 3){ + return InputType.recurrent(outputShape[2], outputShape[1], RNNFormat.NWC); + } return InputType.inferInputType(Nd4j.create(outputShape)); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java index d7b0b3b56..ebc1ca5f0 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java @@ -125,17 +125,9 @@ public class TFOpLayerImpl extends AbstractLayer { } private INDArray runGraph(INDArray input){ - if (input.rank() == 3){ - // TODO make this a preprocessor - input = input.permute(0, 2, 1); - } Map inputMap = new HashMap<>(); inputMap.put(inputNames.get(0), input); INDArray out = graphRunnerService.run(inputMap).values().toArray(new INDArray[0])[0]; - if (out.rank() == 3){ - out = out.permute(0, 2, 1); // TODO post-processing? 
- } - return out; } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java index 120870de9..49e862de7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution1D.java @@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; @@ -94,7 +95,6 @@ public class KerasConvolution1D extends KerasConvolution { IWeightInit init = getWeightInitFromConfig(layerConfig, conf.getLAYER_FIELD_INIT(), enforceTrainingConfig, conf, kerasMajorVersion); - Convolution1DLayer.Builder builder = new Convolution1DLayer.Builder().name(this.layerName) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .activation(getIActivationFromConfig(layerConfig, conf)) @@ -103,7 +103,7 @@ .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) .hasBias(hasBias) - .stride(getStrideFromConfig(layerConfig, 1, conf)[0]); + .stride(getStrideFromConfig(layerConfig, 1, conf)[0]).rnnDataFormat(dimOrder == DimOrder.TENSORFLOW ? RNNFormat.NWC : RNNFormat.NCW); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 1, conf, kerasMajorVersion); if (hasBias) builder.biasInit(0.0); @@ -160,8 +160,8 @@ public InputPreProcessor getInputPreprocessor(InputType...
inputType) throws InvalidKerasConfigurationException { if (inputType.length > 1) throw new InvalidKerasConfigurationException( - "Keras LSTM layer accepts only one input (received " + inputType.length + ")"); - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); + "Keras Conv1D layer accepts only one input (received " + inputType.length + ")"); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], RNNFormat.NCW, layerName); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java index e9c74e78c..51d481e5d 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolution2D.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -101,7 +102,8 @@ public class KerasConvolution2D extends KerasConvolution { .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) .hasBias(hasBias) - .stride(getStrideFromConfig(layerConfig, 2, conf)); + .stride(getStrideFromConfig(layerConfig, 2, conf)) + .dataFormat((dimOrder == DimOrder.TENSORFLOW) ?
CNN2DFormat.NHWC : CNN2DFormat.NCHW); int[] padding = getPaddingFromBorderModeConfig(layerConfig, 2, conf, kerasMajorVersion); if (hasBias) builder.biasInit(0.0); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java index 0968260b7..d60e5f6e8 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java @@ -360,8 +360,19 @@ public class KerasConvolutionUtils { } } else if (dimension == 1) { - int paddingInt = (int) innerConfig.get(layerField); - padding = new int[]{paddingInt, paddingInt}; + Object paddingObj = innerConfig.get(layerField); + if (paddingObj instanceof List){ + List<Integer> paddingList = (List<Integer>) paddingObj; + padding = new int[]{ + paddingList.get(0), + paddingList.get(1) + }; + } + else { + int paddingInt = (int) innerConfig.get(layerField); + padding = new int[]{paddingInt, paddingInt}; + } + } else { throw new UnsupportedKerasConfigurationException( "Keras padding layer not supported"); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java index 196f9d3d9..a807d3357 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional; @@ -27,7 +28,6 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; -import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor; import java.util.Map; @@ -93,11 +93,10 @@ public class KerasFlatten extends KerasLayer { switch (this.getDimOrder()) { case NONE: case THEANO: - preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels()); + preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NCHW); break; case TENSORFLOW: - preprocessor = new TensorFlowCnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), - it.getChannels()); + preprocessor = new CnnToFeedForwardPreProcessor(it.getHeight(), it.getWidth(), it.getChannels(), CNN2DFormat.NHWC); break; default: throw new InvalidKerasConfigurationException("Unknown Keras backend " + this.getDimOrder()); @@ -111,7 +110,7 @@ public class
KerasFlatten extends KerasLayer { // to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten). InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; val inputShape = new long[]{it.getSize()}; - preprocessor = new ReshapePreprocessor(inputShape, inputShape, false); + preprocessor = new ReshapePreprocessor(inputShape, inputShape, false, null); } return preprocessor; } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java index 41254e221..15f3011c4 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVector.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.misc.RepeatVector; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -60,6 +61,7 @@ public class KerasRepeatVector extends KerasLayer { super(layerConfig, enforceTrainingConfig); this.layer = new RepeatVector.Builder().repetitionFactor(getRepeatMultiplier(layerConfig, conf)) + .dataFormat(RNNFormat.NWC) .name(this.layerName).build(); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java index e5f1375d1..bb4dc5ecf 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -111,11 +112,9 @@ public class KerasReshape extends KerasLayer { } else { targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]}; } - preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NCHW); } else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2) - if (inputShape[0] != targetShape[0]) - targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]}; - preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, CNN2DFormat.NHWC); } } else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) { @@ -128,23 +127,23 @@ public class KerasReshape extends KerasLayer { } else { targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] }; } - preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); + preprocessor = new 
ReshapePreprocessor(inputShape, targetShape, false, null); } else { if (inputShape[0] != targetShape[0]) targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] }; - preprocessor = new ReshapePreprocessor(inputShape, targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, targetShape, false, null); } } else if (inputType[0] instanceof InputType.InputTypeRecurrent) { InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0]; val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()}; - preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null); } else if (inputType[0] instanceof InputType.InputTypeFeedForward) { InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0]; val inputShape = new long[]{it.getSize()}; if (targetShape.length == 3) { targetShape = targetShapeForDimOrder(inputShape, targetShape); } - preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false); + preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false, null); } return preprocessor; } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java index 1ee13c0b0..f5a29d4f5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbedding.java @@ -21,6 +21,7 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -121,6 +122,7 @@ public class KerasEmbedding extends KerasLayer { .biasInit(0.0) .l1(this.weightL1Regularization) .l2(this.weightL2Regularization) + .outputDataFormat(RNNFormat.NWC) .hasBias(false); if (embeddingConstraint != null) builder.constrainWeights(embeddingConstraint); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index 1c205bbca..3ab596cd5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -22,11 +22,9 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; -import org.deeplearning4j.nn.conf.layers.InputTypeUtil; -import org.deeplearning4j.nn.conf.layers.LSTM; -import 
org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; @@ -37,6 +35,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.deeplearning4j.nn.weights.IWeightInit; +import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -187,7 +186,7 @@ public class KerasLSTM extends KerasLayer { .weightInitRecurrent(recurrentInit) .biasInit(0.0) // TODO: this is incorrect .l1(this.weightL1Regularization) - .l2(this.weightL2Regularization); + .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC); Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); if(nIn != null) builder.setNIn(nIn); @@ -266,7 +265,8 @@ public class KerasLSTM extends KerasLayer { throw new InvalidKerasConfigurationException("Keras LSTM layer accepts only one input " + "or three (input to LSTM and two state tensors), but " + "received " + inputType.length + "."); - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); + RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName); } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index f6ecbb6a5..3e44bd1f9 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -21,7 +21,9 @@ import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.Layer; @@ -36,6 +38,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.params.SimpleRnnParamInitializer; import org.deeplearning4j.nn.weights.IWeightInit; +import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -155,7 +158,7 @@ public class KerasSimpleRnn extends KerasLayer { .weightInitRecurrent(recurrentInit) .biasInit(0.0) .l1(this.weightL1Regularization) - .l2(this.weightL2Regularization); + .l2(this.weightL2Regularization).dataFormat(RNNFormat.NWC); Integer nIn = KerasLayerUtils.getNInFromInputDim(layerConfig, conf); if(nIn != null) builder.setNIn(nIn); @@ -227,7 +230,8 @@ public class KerasSimpleRnn extends KerasLayer {
throw new InvalidKerasConfigurationException( "Keras SimpleRnn layer accepts only one input (received " + inputType.length + ")"); - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); + RNNFormat f = TimeSeriesUtils.getFormatFromRnnLayer(layer); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], f, layerName); } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java index 3b7cb1721..faa271987 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java @@ -147,7 +147,7 @@ public class KerasBidirectional extends KerasLayer { break; case "SimpleRNN": kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers); - SimpleRnn rnnLayer = (SimpleRnn) ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer(); + Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer(); this.layer = new Bidirectional(mode, rnnLayer); layer.setLayerName(layerName); break; @@ -218,7 +218,7 @@ public class KerasBidirectional extends KerasLayer { if (inputType.length > 1) throw new InvalidKerasConfigurationException( "Keras Bidirectional layer accepts only one input (received " + inputType.length + ")"); - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], layerName); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType[0], ((Bidirectional)layer).getRNNDataFormat(), layerName); } /** diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java index afc9392a5..e6819a3bf 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java @@ -21,6 +21,9 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.DataFormat; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor; @@ -54,25 +57,30 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { private final long[] inputShape; private final long[] targetShape; private boolean hasMiniBatchDimension; - - /** - * @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)} - */ - @Deprecated - public ReshapePreprocessor(long[] inputShape, long[] targetShape) { - this(inputShape, targetShape, false); - } + private DataFormat format; /** * @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension * @param targetShape Target shape, with or without 
leading minibatch dimension, depending on value of hasMiniBatchDimension * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...] */ + public ReshapePreprocessor(long[] inputShape, long[] targetShape, boolean hasMiniBatchDimension) { + this(inputShape, targetShape, hasMiniBatchDimension, null); + } + + /** + * @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension + * @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension + * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...] + * @param dataFormat May be null. If non-null: used to determine the data format of the reshaped output - RNNFormat for 3d (recurrent) output, CNN2DFormat for 4d (convolutional) output + */ public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape, - @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) { + @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension, + @JsonProperty("dataFormat") DataFormat dataFormat) { this.inputShape = inputShape; this.targetShape = targetShape; this.hasMiniBatchDimension = hasMiniBatchDimension; + this.format = dataFormat; } private long[] getShape(long[] originalShape, long minibatch) { @@ -140,13 +148,26 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { ret = InputType.feedForward(shape[1]); break; case 3: - ret = InputType.recurrent(shape[2], shape[1]); + RNNFormat format = RNNFormat.NCW; + if(this.format instanceof RNNFormat) + format = (RNNFormat)this.format; + + ret = InputType.recurrent(shape[2], shape[1], format); break; case 4: if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) { ret = InputType.convolutional(shape[1], shape[2], shape[3]); } else { - ret = InputType.convolutional(shape[2], shape[3], shape[1]); + + CNN2DFormat cnnFormat = CNN2DFormat.NCHW; + if (this.format instanceof CNN2DFormat) + cnnFormat = (CNN2DFormat) this.format; + + if (cnnFormat == CNN2DFormat.NCHW) { + ret = InputType.convolutional(shape[2], shape[3], shape[1], cnnFormat); + } else { + ret = InputType.convolutional(shape[1], shape[2], shape[3], cnnFormat); + } } break; default: diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java index db7d2e990..2e1ac2e51 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/TensorFlowCnnToFeedForwardPreProcessor.java @@ -27,26 +27,25 @@ import org.nd4j.shade.jackson.annotation.JsonCreator; import org.nd4j.shade.jackson.annotation.JsonProperty; /** - * Specialized CnnToFeedForwardInputPreProcessor for use with - * Convolutional layers imported from Keras using the TensorFlow - * backend. - * - * @author dave@skymind.io + * @deprecated Exists only for backward compatibility of older pretrained models. Should not be used. + * Use {@link CnnToFeedForwardPreProcessor} for all new models instead.
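+ * For NHWC models the equivalent behaviour is obtained by passing the data format explicitly, e.g. new CnnToFeedForwardPreProcessor(inputHeight, inputWidth, numChannels, CNN2DFormat.NHWC), as the Keras Flatten import now does elsewhere in this change.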
*/ -@Slf4j +@Slf4j @Deprecated public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreProcessor { - @JsonCreator + @JsonCreator @Deprecated public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight, @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) { super(inputHeight, inputWidth, numChannels); } + @Deprecated public TensorFlowCnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) { super(inputHeight, inputWidth); } + @Deprecated public TensorFlowCnnToFeedForwardPreProcessor() { super(); } @@ -81,4 +80,4 @@ public class TensorFlowCnnToFeedForwardPreProcessor extends CnnToFeedForwardPreP public TensorFlowCnnToFeedForwardPreProcessor clone() { return (TensorFlowCnnToFeedForwardPreProcessor) super.clone(); } -} +} \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java index 3f69cb7d4..a0578a4df 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*; import org.deeplearning4j.nn.modelimport.keras.layers.core.*; import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding; +import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianDropout; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasGaussianNoise; @@ -319,6 +320,8 @@ public class KerasLayerUtils { layer = new KerasELU(layerConfig, enforceTrainingConfig); } else if(layerClassName.equals(conf.getLAYER_CLASS_NAME_SOFTMAX())){ layer = new KerasSoftmax(layerConfig, enforceTrainingConfig); + } else if (layerClassName.equals(conf.getLAYER_CLASS_NAME_LOCALLY_CONNECTED_1D())){ + layer = new KerasLocallyConnected1D(layerConfig, enforceTrainingConfig); } else if (conf instanceof Keras2LayerConfiguration){ Keras2LayerConfiguration k2conf = (Keras2LayerConfiguration)conf; if (layerClassName.equals(k2conf.getTENSORFLOW_OP_LAYER())){ diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java deleted file mode 100644 index cb74b1ed1..000000000 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TFKerasTests.java +++ /dev/null @@ -1,50 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.nn.modelimport.keras; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.junit.Assert; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.resources.Resources; - -import java.io.File; -import java.util.Arrays; - -public class TFKerasTests extends BaseDL4JTest{ - - @Test - public void testModelWithTFOp1() throws Exception{ - File f = Resources.asFile("modelimport/keras/tfkeras/reshape.h5"); - ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath()); - INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3)); - Assert.assertArrayEquals(new long[]{12, 3}, out.shape()); - } - - @Test - public void testModelWithTFOp2() throws Exception{ - File f = Resources.asFile("modelimport/keras/tfkeras/permute.h5"); - ComputationGraph graph = KerasModelImport.importKerasModelAndWeights(f.getAbsolutePath()); - INDArray out = graph.outputSingle(Nd4j.zeros(12, 2, 3)); - // dl4j's feedforward doesn't support 3D output, so batch and time axes gets squashed - long[] expectedShape = new long[]{12 * 2, 5}; - Assert.assertArrayEquals(expectedShape, out.shape()); - } - -} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TestTFKerasModelImport.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TestTFKerasModelImport.java new file mode 100644 index 000000000..455d7ca56 --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/TestTFKerasModelImport.java @@ -0,0 +1,147 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras; + +import org.apache.commons.io.FileUtils; +import org.datavec.python.keras.Model; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.nn.graph.ComputationGraph; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.common.tests.ResourceUtils; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.concurrency.AffinityManager; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.resources.Resources; + +import java.io.File; +import java.util.List; + + +@RunWith(Parameterized.class) +public class TestTFKerasModelImport extends BaseDL4JTest{ + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + private String modelFile; + + @Override + public long getTimeoutMilliseconds(){ + return 300000; // installing TF will take a while + } + + + @Parameterized.Parameters(name = "file={0}") + public static Object[] params() throws Exception { + List<String> paths = ResourceUtils.listClassPathFiles("modelimport/keras/tfkeras", true, false); + return paths.toArray(new String[0]); + } + + public TestTFKerasModelImport(String modelFile){ + this.modelFile = modelFile; + } + + @Test + public void testModelImport() throws Exception{ + testModelImportWithData(modelFile); + } + + private void testModelImportWithData(String path) throws Exception{ + System.out.println(path); + // TODO multi input/output + INDArray inputArray; + INDArray expectedOutputArray; + File f = Resources.asFile(path); //May be in a JAR that HDF5 can't read from + File modelFile = new File(testDir.getRoot(), f.getName()); + FileUtils.copyFile(f, modelFile); + + synchronized (Hdf5Archive.LOCK_OBJECT){ + Hdf5Archive hdf5Archive = new Hdf5Archive(modelFile.getAbsolutePath()); + List<String> rootGroups = hdf5Archive.getGroups(); + if (rootGroups.contains("data")){ + String inputName = hdf5Archive.readAttributeAsString("input_names", "data"); + String outputName = hdf5Archive.readAttributeAsString("output_names", "data"); + inputArray = hdf5Archive.readDataSet(inputName, "data"); + expectedOutputArray = hdf5Archive.readDataSet(outputName, "data"); + } + else { + hdf5Archive.close(); + return; + } + hdf5Archive.close(); + } + INDArray outputArray; + + ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path); + outputArray = dl4jModel.outputSingle(inputArray); + + expectedOutputArray = expectedOutputArray.castTo(DataType.FLOAT); + outputArray = outputArray.castTo(DataType.FLOAT); + if (path.contains("misc_")){ + //shape relaxation + expectedOutputArray = expectedOutputArray.reshape(-1); + outputArray = outputArray.reshape(-1); + } + + System.out.println(outputArray.toString()); + System.out.println(expectedOutputArray.toString()); + Assert.assertArrayEquals(expectedOutputArray.shape(), outputArray.shape()); + Assert.assertTrue(expectedOutputArray.equalsWithEps(outputArray, 1e-3)); + } + + private void testModelImportWithKeras(String path) throws Exception{ + Model kerasModel = new Model(path); + ComputationGraph dl4jModel = KerasModelImport.importKerasModelAndWeights(path); + Assert.assertEquals(kerasModel.numInputs(), dl4jModel.getNumInputArrays()); + Assert.assertEquals(kerasModel.numOutputs(),
dl4jModel.getNumOutputArrays()); + INDArray[] kerasInputArrays = new INDArray[kerasModel.numInputs()]; + INDArray[] dl4jInputArrays = new INDArray[kerasModel.numInputs()]; + + for (int i = 0; i < kerasInputArrays.length; i++) { + long[] shape = kerasModel.inputShapeAt(i); + for (int j = 0; j < shape.length; j++) { + if (shape[j] < 0) { + shape[j] = 1; + } + } + + kerasInputArrays[i] = Nd4j.rand(shape); + dl4jInputArrays[i] = kerasInputArrays[i]; //Both models must receive identical inputs + } + + INDArray[] kerasOut = kerasModel.predict(kerasInputArrays); + INDArray[] dl4jOut = dl4jModel.output(dl4jInputArrays); + + Assert.assertEquals(kerasOut.length, dl4jOut.length); + + for (int i = 0; i < kerasOut.length; i++){ + INDArray kerasOutArr = kerasOut[i]; + kerasOutArr = kerasOutArr.reshape(1, -1); // bit of relaxation on shape + kerasOutArr = kerasOutArr.castTo(DataType.DOUBLE); + Nd4j.getAffinityManager().ensureLocation(dl4jOut[i], AffinityManager.Location.HOST); + INDArray dl4jOutArr = dl4jOut[i].reshape(1, -1); + System.out.println(kerasOutArr.shapeInfoToString()); + System.out.println(dl4jOutArr.shapeInfoToString()); + Assert.assertEquals(kerasOutArr, dl4jOutArr); + } + } +} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java index 4aae27af3..c95a6a8bb 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java @@ -22,7 +22,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.preprocessors.KerasFlattenRnnPreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.PermutePreprocessor; import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor; -import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -34,8 +33,7 @@ public class JsonTest extends BaseDL4JTest { InputPreProcessor[] pp = new InputPreProcessor[] { new KerasFlattenRnnPreprocessor(10, 5), new PermutePreprocessor(new int[]{0,1,2}), - new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}), - new TensorFlowCnnToFeedForwardPreProcessor() + new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}, true, null) }; for(InputPreProcessor p : pp ){ diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index 4d4bf067e..7f566a6cb 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceTo import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import
org.nd4j.resources.Resources; @@ -250,7 +251,7 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { .enforceTrainingConfig(false).buildSequential().getMultiLayerConfiguration(); MultiLayerNetwork model = new MultiLayerNetwork(config); model.init(); - INDArray input = Nd4j.create(50, 500, 1500); + INDArray input = Nd4j.create(DataType.FLOAT, 50, 1500, 500); //NWC format - [Minibatch, seqLength, channels] INDArray out = model.output(input); assertTrue(Arrays.equals(out.shape(), new long[]{50, 64})); } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java index cf51831a2..5a10760d5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java @@ -69,7 +69,7 @@ public class KerasModelImportTest extends BaseDL4JTest { network = KerasModelImport.importKerasSequentialModelAndWeights(Resources.asFile(modelJsonFilename).getAbsolutePath(), Resources.asFile(modelWeightFilename).getAbsolutePath(), false); } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) { - e.printStackTrace(); + log.error("",e); } return network; @@ -80,7 +80,7 @@ public class KerasModelImportTest extends BaseDL4JTest { try { model = KerasModelImport.importKerasSequentialModelAndWeights(Resources.asFile(modelFilename).getAbsolutePath()); } catch (IOException | InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e) { - e.printStackTrace(); + log.error("",e); } return model; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 3e1efa365..f705f9f86 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -87,15 +87,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Rule public final TemporaryFolder testDir = new TemporaryFolder(); - public static final BiFunction nwc2ncwExpected = new BiFunction() { - @Override - public INDArray apply(String s, INDArray array) { - if(array.rank() == 3) - return array.permute(0, 2, 1); //NWC to NCW - return array; - } - }; - @Override public long getTimeoutMilliseconds() { return 180000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources @@ -169,28 +160,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importImdbLstmTfKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, true, true, 
false, false, true, null, null); } @Test public void importImdbLstmThKeras1() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_1_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); } @Test public void importImdbLstmTfKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); } @Test public void importImdbLstmThKeras2() throws Exception { String modelPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/imdb_lstm/imdb_lstm_th_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, false, true, false, false, true, null, null); } /** @@ -262,7 +253,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { String modelPath = "modelimport/keras/examples/simple_flatten_rnn/simple_flatten_rnn_tf_keras_2_model.h5"; String inputsOutputPath = "modelimport/keras/examples/simple_flatten_rnn/" + "simple_flatten_rnn_tf_keras_2_inputs_and_outputs.h5"; - importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, nwc2ncwExpected); + importEndModelTest(modelPath, inputsOutputPath, true, true, false, false, true, null, null); } /** @@ -316,11 +307,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Test public void importAcganDiscriminator() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_discriminator_1_epochs.h5"); - INDArray input = Nd4j.create(10, 1, 28, 28); + INDArray input = Nd4j.create(10, 28, 28, 1); //NHWC INDArray[] output = model.output(input); } - @Test + @Test @Ignore //AB 2020/04/22 Ignored until Keras model import updated to use NHWC support public void importAcganGenerator() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/acgan/acgan_generator_1_epochs.h5"); //System.out.println(model.summary()) ; @@ -403,7 +394,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { // Make predictions int miniBatch = 32; - INDArray input = Nd4j.ones(miniBatch, 4, 10); + INDArray input = Nd4j.ones(miniBatch, 10, 4); //NWC format - with nIn=4, seqLength = 10 INDArray[] out = graph.output(input); // Fit model @@ -450,7 +441,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Test public void importMobileNet() throws Exception { ComputationGraph graph = importFunctionalModelH5Test("modelimport/keras/examples/mobilenet/alternative.hdf5"); - INDArray input = Nd4j.ones(10, 3, 299, 299); + INDArray input = Nd4j.ones(10, 299, 299, 3); graph.output(input); } @@ -462,7 +453,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { int[] inputShape = new int[]{299, 299, 3}; ComputationGraph graph = importFunctionalModelH5Test( 
"modelimport/keras/examples/inception/inception_tf_keras_2.h5", inputShape, false); - INDArray input = Nd4j.ones(10, 3, 299, 299); + INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC graph.output(input); System.out.println(graph.summary()); } @@ -476,7 +467,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { public void importInception() throws Exception { ComputationGraph graph = importFunctionalModelH5Test( "modelimport/keras/examples/inception/inception_v3_complete.h5"); - INDArray input = Nd4j.ones(10, 3, 299, 299); + INDArray input = Nd4j.ones(10, 299, 299, 3); //TF = channels last = NHWC graph.output(input); System.out.println(graph.summary()); } @@ -533,14 +524,14 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { * - Separate (policy and value) residual architecture * - Separate (policy and value) convolutional architecture */ - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importSepConvPolicy() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_policy.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importSepResPolicy() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_policy.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); @@ -548,28 +539,28 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { } - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importSepConvValue() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_conv_value.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importSepResValue() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/sep_res_value.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importDualRes() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_res.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); model.output(input); } - @Test + @Test @Ignore //AB 20200427 Bad keras model - Keras JSON has input shape [null, 10, 19, 19] (i.e., NCHW) but all layers are set to channels_last public void importDualConv() throws Exception { ComputationGraph model = importFunctionalModelH5Test("modelimport/keras/examples/agz/dual_conv.h5"); INDArray input = Nd4j.create(32, 19, 19, 10); @@ -634,16 +625,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { System.out.println("Starting test: " + name); String modelPath = "modelimport/keras/examples/causal_conv1d/" + name; String inputsOutputPath = "modelimport/keras/examples/causal_conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); - 
Function f = new Function() { - @Override - public INDArray apply(INDArray i) { - //NWC to NCW - return i.permute(0, 2, 1); - } - }; MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, - true, true, false, f, nwc2ncwExpected); + true, true, false, null, null); Layer l = net.getLayer(0); Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig(); assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode()); @@ -707,25 +691,9 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { System.out.println("Starting test: " + name); String modelPath = "modelimport/keras/examples/conv1d/" + name; String inputsOutputPath = "modelimport/keras/examples/conv1d/" + (name.substring(0,name.length()-"model.h5".length()) + "inputs_and_outputs.h5"); - Function f = name.contains("_cf_") ? null : new Function() { - @Override - public INDArray apply(INDArray i) { - //NWC to NCW - return i.permute(0, 2, 1); - } - }; - - BiFunction f2 = name.contains("_cf_") ? null : new BiFunction() { - @Override - public INDArray apply(String s, INDArray array) { -// if("conv".equals(s)){ - return array.permute(0, 2, 1); -// } - } - }; importEndModelTest(modelPath, inputsOutputPath, true, true, - true, true, false, f, f2); + true, true, false, null, null); //f, f2); } } @@ -882,8 +850,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { INDArray[] inputs = new INDArray[inputNames.size()]; for (int i = 0; i < inputNames.size(); i++) { inputs[i] = archive.readDataSet(inputNames.get(i), GROUP_ATTR_INPUTS); - if (inputs[i].shape().length == 4 && tensorFlowImageDimOrdering) - inputs[i] = inputs[i].permute(0, 3, 1, 2); } return inputs; } @@ -893,8 +859,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { Map activations = new HashMap(); for (String layerName : archive.getDataSets(GROUP_ACTIVATIONS)) { INDArray activation = archive.readDataSet(layerName, GROUP_ACTIVATIONS); - if (activation.shape().length == 4 && tensorFlowImageDimOrdering) - activation = activation.permute(0, 3, 1, 2); activations.put(layerName, activation); } return activations; @@ -907,8 +871,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { INDArray[] outputs = new INDArray[outputNames.size()]; for (int i = 0; i < outputNames.size(); i++) { outputs[i] = archive.readDataSet(outputNames.get(i), GROUP_ATTR_OUTPUTS); - if (outputs[i].shape().length == 4 && tensorFlowImageDimOrdering) - outputs[i] = outputs[i].permute(0, 3, 1, 2); } return outputs; } @@ -920,8 +882,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { INDArray[] predictions = new INDArray[outputNames.size()]; for (int i = 0; i < outputNames.size(); i++) { predictions[i] = archive.readDataSet(outputNames.get(i), GROUP_PREDICTIONS); - if (predictions[i].shape().length == 4 && tensorFlowImageDimOrdering) - predictions[i] = predictions[i].permute(0, 3, 1, 2); } return predictions; } @@ -941,6 +901,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { // skip too small absolute inputs if (Math.abs(aAbsMax) > threshold && Math.abs(bAbsMax) > threshold) { boolean eq = expected.equalsWithEps(actual.castTo(expected.dataType()), eps); + if(!eq){ + System.out.println("Expected: " + Arrays.toString(expected.shape()) + ", actual: " + Arrays.toString(actual.shape())); + System.out.println("Expected:\n" + expected); + System.out.println("Actual: \n" + actual); + } assertTrue("Output differs: " + label, eq); } } diff --git 
a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java index 75334bcd0..7204ebb82 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java @@ -176,10 +176,10 @@ public class KerasWeightSettingTests extends BaseDL4JTest { INDArray bias = model.getLayer(0).getParam("b"); assertEquals(6, bias.length()); - INDArray input = Nd4j.ones(1, 5, 3, 4); + INDArray input = Nd4j.ones(1, 3, 4, 5); //NHWC INDArray output = model.output(input); - assertArrayEquals(new long[] {1, 6, 1, 2}, output.shape()); + assertArrayEquals(new long[] {1, 1, 2, 6}, output.shape()); //NHWC logSuccess(modelPath); } @@ -224,7 +224,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest { INDArray input = Nd4j.zeros(mb, inputLength); INDArray output = model.output(input); - assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape()); + assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC logSuccess(modelPath); } @@ -238,9 +238,9 @@ public class KerasWeightSettingTests extends BaseDL4JTest { KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class); MultiLayerNetwork model = loadMultiLayerNetwork(modelPath, false); - INDArray input = Nd4j.zeros(10, 4, 6, 6); + INDArray input = Nd4j.zeros(10, 6, 6, 4); INDArray output = model.output(input); - assertArrayEquals(new long[]{10, 16, 3, 3}, output.shape()); + assertArrayEquals(new long[]{10, 3, 3, 16}, output.shape()); logSuccess(modelPath); } @@ -248,10 +248,11 @@ public class KerasWeightSettingTests extends BaseDL4JTest { KerasLayer.registerCustomLayer("Lambda", KerasSpaceToDepth.class); ComputationGraph model = loadComputationalGraph(modelPath, false); - INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)}; +// INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)}; + INDArray input[] = new INDArray[]{Nd4j.zeros(10, 6, 6, 4), Nd4j.zeros(10, 3, 3, 16)}; INDArray[] output = model.output(input); log.info(Arrays.toString(output[0].shape())); - assertArrayEquals(new long[]{10, 32, 3, 3}, output[0].shape()); + assertArrayEquals(new long[]{10, 3, 3, 32}, output[0].shape()); logSuccess(modelPath); } @@ -278,7 +279,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest { INDArray inEmbedding = Nd4j.zeros(mb, inputLength); INDArray output = model.output(inEmbedding); - assertArrayEquals(new long[]{mb, nOut, inputLength}, output.shape()); + assertArrayEquals(new long[]{mb, inputLength, nOut}, output.shape()); //NWC format logSuccess(modelPath); } @@ -304,7 +305,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest { INDArray inEmbedding = Nd4j.zeros(mb, inputLength); INDArray output = model.output(inEmbedding); - assertArrayEquals(new long[]{mb, nOut, inputLength - kernel + 1}, output.shape()); + assertArrayEquals(new long[]{mb, inputLength - kernel + 1, nOut}, output.shape()); //NWC logSuccess(modelPath); } diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java 
b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java index 6610e75f9..33c651dee 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java @@ -210,7 +210,6 @@ public class NearestNeighborsServer extends AbstractVerticle { return; } catch (Throwable e) { log.error("Error in POST /knn",e); - e.printStackTrace(); rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) .end("Error parsing request - " + e.getMessage()); return; @@ -265,7 +264,6 @@ public class NearestNeighborsServer extends AbstractVerticle { .end(j); } catch (Throwable e) { log.error("Error in POST /knnnew",e); - e.printStackTrace(); rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) .end("Error parsing request - " + e.getMessage()); return; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/recognition/impl/StopRecognition.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/recognition/impl/StopRecognition.java index bbdd2597e..4449155af 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/recognition/impl/StopRecognition.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/main/java/org/ansj/recognition/impl/StopRecognition.java @@ -75,7 +75,6 @@ public class StopRecognition implements Recognition { try { regexList.add(Pattern.compile(regex)); } catch (Exception e) { - e.printStackTrace(); LOG.error("regex err : " + regex, e); } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/compile/DictionaryCompilerBase.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/compile/DictionaryCompilerBase.java index b090c7d07..e5f1c3379 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/compile/DictionaryCompilerBase.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/compile/DictionaryCompilerBase.java @@ -20,10 +20,12 @@ import com.atilika.kuromoji.dict.CharacterDefinitions; import com.atilika.kuromoji.dict.ConnectionCosts; import com.atilika.kuromoji.dict.UnknownDictionary; import com.atilika.kuromoji.trie.DoubleArrayTrie; +import lombok.extern.slf4j.Slf4j; import java.io.*; import java.util.List; +@Slf4j public abstract class DictionaryCompilerBase { public void build(String inputDirname, String outputDirname, String encoding, boolean compactTries) @@ -66,7 +68,7 @@ public abstract class DictionaryCompilerBase { } } } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } ProgressLog.end(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/util/FileResourceResolver.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/util/FileResourceResolver.java index 8338f42a0..2e6377ec5 100644 --- 
a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/util/FileResourceResolver.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/main/java/com/atilika/kuromoji/util/FileResourceResolver.java @@ -33,7 +33,6 @@ public class FileResourceResolver implements ResourceResolver { try { KuromojiBinFilesFetcher.downloadAndUntar(); } catch (IOException e) { - e.printStackTrace(); log.error("IOException : ", e); } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml index 7aa6090e1..e5ee63ea0 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/pom.xml @@ -95,6 +95,13 @@ <artifactId>deeplearning4j-common-tests</artifactId> <version>${project.version}</version> <scope>test</scope> + <exclusions> + <exclusion> + <groupId>org.springframework</groupId> + <artifactId>spring-core</artifactId> + </exclusion> + </exclusions> </dependency> diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java index 03d2462a5..f82838ea5 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/corpora/sentiwordnet/SWN3.java @@ -16,6 +16,7 @@ package org.deeplearning4j.text.corpora.sentiwordnet; +import lombok.extern.slf4j.Slf4j; import org.nd4j.shade.guava.collect.Sets; import org.apache.uima.analysis_engine.AnalysisEngine; import org.apache.uima.cas.CAS; @@ -37,6 +38,7 @@ import java.util.*; * @author Adam Gibson * */ +@Slf4j public class SWN3 implements Serializable { /** * @@ -120,7 +122,7 @@ public class SWN3 implements Serializable { try { csv.close(); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/UimaTokenizer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/UimaTokenizer.java index 830165ec0..b1068e0ab 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/UimaTokenizer.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/UimaTokenizer.java @@ -16,6 +16,7 @@ package org.deeplearning4j.text.tokenization.tokenizer; +import lombok.extern.slf4j.Slf4j; import org.apache.uima.cas.CAS; import org.apache.uima.fit.util.JCasUtil; import org.cleartk.token.type.Token; @@ -32,6 +33,7 @@ import java.util.List; * @author Adam Gibson * */ +@Slf4j public class UimaTokenizer implements Tokenizer { private List<String> tokens; @@ -66,7 +68,7 @@ public class UimaTokenizer implements Tokenizer { } catch (Exception e) { - e.printStackTrace(); + log.error("",e); throw new RuntimeException(e); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java index
7dcfb160a..f205a0994 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.models.word2vec; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.apache.commons.io.LineIterator; import org.deeplearning4j.text.sentenceiterator.CollectionSentenceIterator; @@ -64,6 +65,7 @@ import static org.junit.Assert.*; /** * @author jeffreytang */ +@Slf4j public class Word2VecTests extends BaseDL4JTest { private static final Logger log = LoggerFactory.getLogger(Word2VecTests.class); @@ -621,7 +623,7 @@ public class Word2VecTests extends BaseDL4JTest { unserialized = Word2Vec.fromJson(json); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.java index a84cd8d4c..fe7f117af 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.java @@ -70,7 +70,7 @@ public class NearestVertexWalker implements GraphWalk public void reset(boolean shuffle) { position.set(0); if (shuffle) { - log.debug("Calling shuffle() on entries..."); + log.trace("Calling shuffle() on entries..."); // https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm for (int i = order.length - 1; i > 0; i--) { int j = rng.nextInt(i + 1); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalker.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalker.java index b422a52d1..a033a4a81 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalker.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalker.java @@ -249,7 +249,7 @@ public class RandomWalker implements GraphWalker { public void reset(boolean shuffle) { this.position.set(0); if (shuffle) { - logger.debug("Calling shuffle() on entries..."); + logger.trace("Calling shuffle() on entries..."); // https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm for (int i = order.length - 1; i > 0; i--) { int j = rng.nextInt(i + 1); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/SerializingListener.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/SerializingListener.java index 8d49720e3..2afbde2b7 100644 --- 
a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/SerializingListener.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/SerializingListener.java @@ -17,6 +17,7 @@ package org.deeplearning4j.models.sequencevectors.listeners; import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.models.sequencevectors.SequenceVectors; import org.deeplearning4j.models.sequencevectors.enums.ListenerEvent; import org.deeplearning4j.models.sequencevectors.interfaces.VectorsListener; @@ -34,6 +35,7 @@ import java.util.concurrent.Semaphore; * * @author raver119@gmail.com */ +@Slf4j public class SerializingListener implements VectorsListener { private File targetFolder = new File("./"); private String modelPrefix = "Model_"; @@ -96,7 +98,7 @@ public class SerializingListener implements VectorsLi } } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } finally { locker.release(); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIterator.java index ac974d811..01c0bb9e7 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIterator.java @@ -147,7 +147,7 @@ public class ParallelTransformerIterator extends BasicTransformerIterator { try { buffer.put(futureSequence); } catch (InterruptedException e) { - e.printStackTrace(); + log.error("",e); } } /* else diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIterator.java index cf9337b5b..dd87f9b1f 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIterator.java @@ -21,6 +21,7 @@ import lombok.NonNull; import org.deeplearning4j.parallelism.AsyncIterator; import java.util.Iterator; +import java.util.NoSuchElementException; /** * @author raver119@gmail.com @@ -77,7 +78,7 @@ public class AsyncLabelAwareIterator implements LabelAwareIterator, Iterator { private BufferedReader reader; @@ -113,7 +115,7 @@ public class BasicLineIterator implements SentenceIterator, Iterable { reader.close(); } catch (Exception e) { // do nothing here - e.printStackTrace(); + log.error("",e); } super.finalize(); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/StreamLineIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/StreamLineIterator.java index 
737e43b26..156e88958 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/StreamLineIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/StreamLineIterator.java @@ -17,6 +17,7 @@ package org.deeplearning4j.text.sentenceiterator; import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.text.documentiterator.DocumentIterator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -35,6 +36,7 @@ import java.util.concurrent.atomic.AtomicBoolean; * * @author raver119@gmail.com */ +@Slf4j public class StreamLineIterator implements SentenceIterator { private DocumentIterator iterator; private int linesToFetch; @@ -64,7 +66,7 @@ public class StreamLineIterator implements SentenceIterator { currentReader = null; } } catch (IOException e) { - e.printStackTrace(); + log.error("",e); throw new RuntimeException(e); } } @@ -145,7 +147,7 @@ public class StreamLineIterator implements SentenceIterator { try { this.onlyStream.reset(); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); throw new RuntimeException(e); } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java index eba2bce85..2e3ec06c2 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java @@ -183,7 +183,7 @@ public class FastTextTest extends BaseDL4JTest { fastText.loadIterator(); } catch (IOException e) { - log.error(e.toString()); + log.error("",e); } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java index 14495ffaf..a7423984c 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java @@ -1164,7 +1164,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { unserialized = ParagraphVectors.fromJson(json); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java index 1bfe47c41..1d03ac033 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java @@ -101,7 +101,7 @@ public class RandomWalkerTest extends 
BaseDL4JTest { @Test public void testGraphTraverseRandom1() throws Exception { RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graph) - .setNoEdgeHandling(NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED).setWalkLength(3).build(); + .setNoEdgeHandling(NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED).setWalkLength(3).build(); int cnt = 0; while (walker.hasNext()) { @@ -123,9 +123,10 @@ public class RandomWalkerTest extends BaseDL4JTest { @Test public void testGraphTraverseRandom2() throws Exception { RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graph) - .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).setWalkLength(20) - .setWalkDirection(WalkDirection.FORWARD_UNIQUE) - .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build(); + .setSeed(12345) + .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).setWalkLength(20) + .setWalkDirection(WalkDirection.FORWARD_UNIQUE) + .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build(); int cnt = 0; while (walker.hasNext()) { @@ -147,9 +148,9 @@ public class RandomWalkerTest extends BaseDL4JTest { @Test public void testGraphTraverseRandom3() throws Exception { RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graph) - .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).setWalkLength(20) - .setWalkDirection(WalkDirection.FORWARD_UNIQUE) - .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).build(); + .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).setWalkLength(20) + .setWalkDirection(WalkDirection.FORWARD_UNIQUE) + .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).build(); try { while (walker.hasNext()) { @@ -169,9 +170,10 @@ public class RandomWalkerTest extends BaseDL4JTest { @Test public void testGraphTraverseRandom4() throws Exception { RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graphBig) - .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).setWalkLength(20) - .setWalkDirection(WalkDirection.FORWARD_UNIQUE) - .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build(); + .setSeed(12345) + .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).setWalkLength(20) + .setWalkDirection(WalkDirection.FORWARD_UNIQUE) + .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build(); Sequence sequence1 = walker.next(); @@ -185,8 +187,8 @@ public class RandomWalkerTest extends BaseDL4JTest { @Test public void testGraphTraverseRandom5() throws Exception { RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graphBig) - .setWalkLength(20).setWalkDirection(WalkDirection.FORWARD_UNIQUE) - .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build(); + .setWalkLength(20).setWalkDirection(WalkDirection.FORWARD_UNIQUE) + .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build(); Sequence sequence1 = walker.next(); @@ -200,8 +202,8 @@ public class RandomWalkerTest extends BaseDL4JTest { @Test public void testGraphTraverseRandom6() throws Exception { GraphWalker walker = new RandomWalker.Builder<>(graphDirected).setWalkLength(20) - .setWalkDirection(WalkDirection.FORWARD_UNIQUE) - .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build(); + .setWalkDirection(WalkDirection.FORWARD_UNIQUE) + .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build(); Sequence sequence = walker.next(); assertEquals("0", sequence.getElements().get(0).getLabel()); diff --git 
a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java index 4e70cb212..b7aff923e 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.models.sequencevectors.serialization; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang.StringUtils; import org.deeplearning4j.BaseDL4JTest; @@ -48,6 +49,7 @@ import java.util.Collections; import static org.junit.Assert.*; +@Slf4j public class WordVectorSerializerTest extends BaseDL4JTest { private AbstractCache cache; @@ -97,7 +99,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest { byte[] bytesResult = baos.toByteArray(); deser = WordVectorSerializer.readSequenceVectors(new ByteArrayInputStream(bytesResult), true); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } @@ -175,7 +177,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest { byte[] bytesResult = baos.toByteArray(); deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } @@ -223,7 +225,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest { byte[] bytesResult = baos.toByteArray(); deser = WordVectorSerializer.readWord2Vec(new ByteArrayInputStream(bytesResult), true); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } @@ -268,7 +270,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest { ByteArrayOutputStream baos = new ByteArrayOutputStream(); deser = WordVectorSerializer.readLookupTable(file); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } assertEquals(lookupTable.getVocab().totalWordOccurrences(), ((InMemoryLookupTable)deser).getVocab().totalWordOccurrences()); @@ -306,7 +308,7 @@ public class WordVectorSerializerTest extends BaseDL4JTest { ByteArrayOutputStream baos = new ByteArrayOutputStream(); deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data")); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java index 8e434644f..c2770486d 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java @@ -140,7 +140,7 @@ public class AbstractCacheTest extends BaseDL4JTest { unserialized = AbstractCache.fromJson(json); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } 
assertEquals(cache.totalWordOccurrences(),unserialized.totalWordOccurrences()); @@ -175,7 +175,7 @@ public class AbstractCacheTest extends BaseDL4JTest { unserialized = AbstractCache.fromJson(json); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); fail(); } assertEquals(cache.totalWordOccurrences(),unserialized.totalWordOccurrences()); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/resources/logback-test.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/resources/logback-test.xml new file mode 100644 index 000000000..69246755b --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/resources/logback-test.xml @@ -0,0 +1,50 @@ +<!-- + ~ Copyright (c) 2020 Konduit K.K. + ~ + ~ This program and the accompanying materials are made available under the + ~ terms of the Apache License, Version 2.0 which is available at + ~ https://www.apache.org/licenses/LICENSE-2.0. + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + ~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + ~ License for the specific language governing permissions and limitations + ~ under the License. + ~ + ~ SPDX-License-Identifier: Apache-2.0 + --> + +<configuration> + + <appender name="FILE" class="ch.qos.logback.core.FileAppender"> + <file>logs/application.log</file> + <encoder> + <pattern>%logger{15} - %message%n%xException{5} + </pattern> + </encoder> + </appender> + + <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> + <encoder> + <pattern>%logger{15} - %message%n%xException{5} + </pattern> + </encoder> + </appender> + + <logger name="org.apache.catalina.core" level="DEBUG" /> + <logger name="org.springframework" level="DEBUG" /> + <logger name="org.deeplearning4j" level="INFO" /> + <logger name="org.datavec" level="INFO" /> + + <root level="ERROR"> + <appender-ref ref="STDOUT" /> + <appender-ref ref="FILE" /> + </root> + +</configuration> \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java index 98cce2d27..0fe2a8689 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.api.layers; import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.Gradient; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Pair; @@ -98,4 +99,5 @@ public interface RecurrentLayer extends Layer { */ Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr); + } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java new file mode 100644 index 000000000..b40d76bff --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/CNN2DFormat.java @@ -0,0 +1,31 @@ +package org.deeplearning4j.nn.conf; + +/** + * CNN2DFormat defines the format of the activations (including input images) in to and out of all 2D convolution layers in + * Deeplearning4j. Default value is NCHW.
+ *
+ * NCHW = "channels first" - arrays of shape [minibatch, channels, height, width]
+ * NHWC = "channels last" - arrays of shape [minibatch, height, width, channels]
+ * + * @author Alex Black + */ +public enum CNN2DFormat implements DataFormat { + NCHW, + NHWC; + + /** + * Returns a string that explains the dimensions:
+ * NCHW -> returns "[minibatch, channels, height, width]"
+ * NHWC -> returns "[minibatch, height, width, channels]" + */ + public String dimensionNames(){ + switch (this){ + case NCHW: + return "[minibatch, channels, height, width]"; + case NHWC: + return "[minibatch, height, width, channels]"; + default: + throw new IllegalStateException("Unknown enum: " + this); //Should never happen + } + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/DataFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/DataFormat.java new file mode 100644 index 000000000..bde392a58 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/DataFormat.java @@ -0,0 +1,26 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.nn.conf; + +import org.deeplearning4j.nn.conf.serde.format.DataFormatDeserializer; +import org.deeplearning4j.nn.conf.serde.format.DataFormatSerializer; +import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; +import org.nd4j.shade.jackson.databind.annotation.JsonSerialize; + +@JsonSerialize(using = DataFormatSerializer.class) +@JsonDeserialize(using = DataFormatDeserializer.class) +public interface DataFormat { +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java index b1ac6a968..1a61acc4d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java @@ -663,7 +663,7 @@ public class MultiLayerConfiguration implements Serializable, Cloneable { BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer; val nIn = brl.getNIn(); if (nIn > 0) { - inputType = InputType.recurrent(nIn); + inputType = InputType.recurrent(nIn, brl.getRnnDataFormat()); } } else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer || firstLayer instanceof OutputLayer) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java new file mode 100644 index 000000000..c12857178 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/RNNFormat.java @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.deeplearning4j.nn.conf; + +/** + * NCW = "channels first" - arrays of shape [minibatch, channels, width]
+ * NWC = "channels last" - arrays of shape [minibatch, width, channels]
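+ *
+ * For example (illustrative values): a minibatch of 16 sequences, each 100 time steps long with 26 features
+ * per step, is an array of shape [16, 26, 100] in NCW format, or [16, 100, 26] in NWC format.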
+ * "width" corresponds to sequence length and "channels" corresponds to sequence item size. + */ + +public enum RNNFormat implements DataFormat { + NCW, + NWC +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java index 77dd41c3a..726a68403 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java @@ -18,6 +18,8 @@ package org.deeplearning4j.nn.conf.graph; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; import org.deeplearning4j.nn.conf.layers.Convolution3D; @@ -38,6 +40,8 @@ import org.nd4j.linalg.api.ndarray.INDArray; */ public class MergeVertex extends GraphVertex { + protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format + @Override public MergeVertex clone() { return new MergeVertex(); @@ -76,7 +80,7 @@ public class MergeVertex extends GraphVertex { @Override public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype) { - return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype); + return new org.deeplearning4j.nn.graph.vertex.impl.MergeVertex(graph, name, idx, networkDatatype, mergeAxis); } @Override @@ -126,6 +130,7 @@ public class MergeVertex extends GraphVertex { //FF or RNN data inputs int size = 0; InputType.Type type = null; + RNNFormat format = null; for (int i = 0; i < vertexInputs.length; i++) { if (vertexInputs[i].getType() != first.getType()) { throw new InvalidInputTypeException( @@ -142,6 +147,8 @@ public class MergeVertex extends GraphVertex { break; case RNN: thisSize = ((InputType.InputTypeRecurrent) vertexInputs[i]).getSize(); + format = ((InputType.InputTypeRecurrent) vertexInputs[i]).getFormat(); + this.mergeAxis = format == RNNFormat.NCW ? 1 : 2; type = InputType.Type.RNN; break; default: @@ -160,7 +167,7 @@ public class MergeVertex extends GraphVertex { return InputType.feedForward(size); } else { val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength(); - return InputType.recurrent(size, tsLength); + return InputType.recurrent(size, tsLength, format); } } else { //size is unknown @@ -168,13 +175,14 @@ public class MergeVertex extends GraphVertex { return InputType.feedForward(-1); } else { val tsLength = ((InputType.InputTypeRecurrent) vertexInputs[0]).getTimeSeriesLength(); - return InputType.recurrent(-1, tsLength); + return InputType.recurrent(-1, tsLength, format); } } } else { //CNN inputs... also check that the channels, width and heights match: InputType.InputTypeConvolutional firstConv = (InputType.InputTypeConvolutional) first; + CNN2DFormat format = firstConv.getFormat(); val fd = firstConv.getChannels(); val fw = firstConv.getWidth(); @@ -206,7 +214,8 @@ public class MergeVertex extends GraphVertex { depthSum += od; } - return InputType.convolutional(fh, fw, depthSum); + this.mergeAxis = format == CNN2DFormat.NCHW ? 
1 : 3; + return InputType.convolutional(fh, fw, depthSum, format); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java index 047618661..ce5e8c78f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java @@ -20,6 +20,9 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NoArgsConstructor; +import org.deeplearning4j.nn.conf.DataFormat; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.shade.jackson.annotation.JsonIgnore; @@ -89,7 +92,11 @@ public abstract class InputType implements Serializable { * @return InputTypeFeedForward */ public static InputType feedForward(long size) { - return new InputTypeFeedForward(size); + return new InputTypeFeedForward(size, null); + } + + public static InputType feedForward(long size, DataFormat timeDistributedFormat) { + return new InputTypeFeedForward(size,timeDistributedFormat); } /** @@ -110,9 +117,16 @@ public abstract class InputType implements Serializable { * @return InputTypeRecurrent */ public static InputType recurrent(long size, long timeSeriesLength) { - return new InputTypeRecurrent(size, timeSeriesLength); + return new InputTypeRecurrent(size, timeSeriesLength, RNNFormat.NCW); } + public static InputType recurrent(long size, RNNFormat format){ + return new InputTypeRecurrent(size, format); + } + + public static InputType recurrent(long size, long timeSeriesLength, RNNFormat format){ + return new InputTypeRecurrent(size, timeSeriesLength, format); + } /** * Input type for convolutional (CNN) data, that is 4d with shape [miniBatchSize, channels, height, width]. * For CNN data that has been flattened, use {@link #convolutionalFlat(long, long, long)} @@ -123,7 +137,11 @@ public abstract class InputType implements Serializable { * @return InputTypeConvolutional */ public static InputType convolutional(long height, long width, long depth) { - return new InputTypeConvolutional(height, width, depth); + return convolutional(height, width, depth, CNN2DFormat.NCHW); + } + + public static InputType convolutional(long height, long width, long depth, CNN2DFormat format){ + return new InputTypeConvolutional(height, width, depth, format); } /** @@ -177,9 +195,11 @@ public abstract class InputType implements Serializable { @EqualsAndHashCode(callSuper = false) public static class InputTypeFeedForward extends InputType { private long size; + private DataFormat timeDistributedFormat; - public InputTypeFeedForward(@JsonProperty("size") long size) { + public InputTypeFeedForward(@JsonProperty("size") long size, @JsonProperty("timeDistributedFormat") DataFormat timeDistributedFormat) { this.size = size; + this.timeDistributedFormat = timeDistributedFormat; } @Override @@ -189,7 +209,7 @@ public abstract class InputType implements Serializable { @Override public String toString() { - return "InputTypeFeedForward(" + size + ")"; + return "InputTypeFeedForward(" + size + (timeDistributedFormat != null ? 
"," + timeDistributedFormat : "") + ")"; } @Override @@ -210,14 +230,23 @@ public abstract class InputType implements Serializable { public static class InputTypeRecurrent extends InputType { private long size; private long timeSeriesLength; - + private RNNFormat format = RNNFormat.NCW; public InputTypeRecurrent(long size) { this(size, -1); } + public InputTypeRecurrent(long size, long timeSeriesLength){ + this(size, timeSeriesLength, RNNFormat.NCW); + } - public InputTypeRecurrent(@JsonProperty("size") long size, @JsonProperty("timeSeriesLength") long timeSeriesLength) { + public InputTypeRecurrent(long size, RNNFormat format){ + this(size, -1, format); + } + public InputTypeRecurrent(@JsonProperty("size") long size, + @JsonProperty("timeSeriesLength") long timeSeriesLength, + @JsonProperty("format") RNNFormat format) { this.size = size; this.timeSeriesLength = timeSeriesLength; + this.format = format; } @Override @@ -228,9 +257,9 @@ public abstract class InputType implements Serializable { @Override public String toString() { if (timeSeriesLength > 0) { - return "InputTypeRecurrent(" + size + ",timeSeriesLength=" + timeSeriesLength + ")"; + return "InputTypeRecurrent(" + size + ",timeSeriesLength=" + timeSeriesLength + ",format=" + format + ")"; } else { - return "InputTypeRecurrent(" + size + ")"; + return "InputTypeRecurrent(" + size + ",format=" + format + ")"; } } @@ -245,8 +274,23 @@ public abstract class InputType implements Serializable { @Override public long[] getShape(boolean includeBatchDim) { - if(includeBatchDim) return new long[]{-1, size, timeSeriesLength}; - else return new long[]{size, timeSeriesLength}; + if (includeBatchDim){ + if (format == RNNFormat.NCW){ + return new long[]{-1, size, timeSeriesLength}; + } + else{ + return new long[]{-1, timeSeriesLength, size}; + } + + } + else{ + if (format == RNNFormat.NCW){ + return new long[]{size, timeSeriesLength}; + } + else{ + return new long[]{timeSeriesLength, size}; + } + } } } @@ -257,11 +301,19 @@ public abstract class InputType implements Serializable { private long height; private long width; private long channels; + private CNN2DFormat format = CNN2DFormat.NCHW; //Default for JSON deserialization of older configurations - public InputTypeConvolutional(@JsonProperty("height") long height, @JsonProperty("width") long width, @JsonProperty("channels") long channels) { + public InputTypeConvolutional(@JsonProperty("height") long height, @JsonProperty("width") long width, + @JsonProperty("channels") long channels, @JsonProperty("format") CNN2DFormat format) { this.height = height; this.width = width; this.channels = channels; + if(format != null) + this.format = format; + } + + public InputTypeConvolutional(long height, long width, long channels) { + this(height, width, channels, CNN2DFormat.NCHW); } /** @@ -292,7 +344,7 @@ public abstract class InputType implements Serializable { @Override public String toString() { - return "InputTypeConvolutional(h=" + height + ",w=" + width + ",c=" + channels + ")"; + return "InputTypeConvolutional(h=" + height + ",w=" + width + ",c=" + channels + "," + format + ")"; } @Override @@ -302,8 +354,13 @@ public abstract class InputType implements Serializable { @Override public long[] getShape(boolean includeBatchDim) { - if(includeBatchDim) return new long[]{-1, channels, height, width}; - else return new long[]{channels, height, width}; + if(format == CNN2DFormat.NCHW){ + if(includeBatchDim) return new long[]{-1, channels, height, width}; + else return new long[]{channels, height, 
width}; + } else { + if(includeBatchDim) return new long[]{-1, height, width, channels}; + else return new long[]{height, width, channels}; + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java index 07bb3d674..0b98dfad9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.weights.IWeightInit; @@ -35,10 +36,12 @@ import java.util.List; public abstract class BaseRecurrentLayer extends FeedForwardLayer { protected IWeightInit weightInitFnRecurrent; + protected RNNFormat rnnDataFormat = RNNFormat.NCW; protected BaseRecurrentLayer(Builder builder) { super(builder); this.weightInitFnRecurrent = builder.weightInitFnRecurrent; + this.rnnDataFormat = builder.rnnDataFormat; } @Override @@ -51,7 +54,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType; - return InputType.recurrent(nOut, itr.getTimeSeriesLength()); + return InputType.recurrent(nOut, itr.getTimeSeriesLength(), itr.getFormat()); } @Override @@ -61,15 +64,16 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { + "\"): expect RNN input type with size > 0. Got: " + inputType); } + InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; if (nIn <= 0 || override) { - InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; this.nIn = r.getSize(); } + this.rnnDataFormat = r.getFormat(); } @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat,getLayerName()); } @NoArgsConstructor @@ -77,6 +81,12 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { @Setter public static abstract class Builder> extends FeedForwardLayer.Builder { + /** + * Set the format of data expected by the RNN. NCW = [miniBatchSize, size, timeSeriesLength], + * NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW. + */ + protected RNNFormat rnnDataFormat = RNNFormat.NCW; + /** * Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no * constraints.
Constraints can be used to enforce certain conditions (non-negativity of parameters, @@ -163,5 +173,10 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer { this.setWeightInitFnRecurrent(new WeightInitDistribution(dist)); return (T) this; } + + public T dataFormat(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + return (T)this; + } } }
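For context, a minimal usage sketch of the rnnDataFormat option added above (layer and input sizes are illustrative, not from this patch; this assumes LSTM picks up the method via BaseRecurrentLayer.Builder, which it extends):

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.RNNFormat;
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.LSTM;

    // Channels-last style sequence input: [miniBatchSize, timeSeriesLength, size]
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new LSTM.Builder().nOut(32).dataFormat(RNNFormat.NWC).build())
            // setNIn(...) above also copies the format from the InputType into the
            // layer, so the explicit dataFormat(...) call and the InputType agree
            .setInputType(InputType.recurrent(8, 100, RNNFormat.NWC))
            .build();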
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java index f95421585..dcced3aeb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java @@ -20,6 +20,7 @@ import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -60,6 +61,7 @@ public class BatchNormalization extends FeedForwardLayer { protected boolean lockGammaBeta = false; protected boolean cudnnAllowFallback = true; protected boolean useLogStd = false; //Default for deserialized models (1.0.0-beta3) and earlier: store variance as variance. Post 1.0.0-beta3: use log stdev instead + protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; //Default for deserialized models, 1.0.0-beta6 and earlier private BatchNormalization(Builder builder) { super(builder); @@ -71,6 +73,7 @@ this.lockGammaBeta = builder.lockGammaBeta; this.cudnnAllowFallback = builder.cudnnAllowFallback; this.useLogStd = builder.useLogStd; + this.cnn2DFormat = builder.cnn2DFormat; initializeConstraints(builder); } @@ -138,6 +141,7 @@ break; case CNN: nIn = ((InputType.InputTypeConvolutional) inputType).getChannels(); + cnn2DFormat = ((InputType.InputTypeConvolutional) inputType).getFormat(); break; case CNN3D: nIn = ((InputType.InputTypeConvolutional3D) inputType).getChannels(); @@ -307,6 +311,8 @@ */ protected boolean useLogStd = true; + protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; //Default for deserialized models, 1.0.0-beta6 and earlier + public Builder(double decay, boolean isMinibatch) { this.setDecay(decay); this.setMinibatch(isMinibatch); @@ -329,6 +335,16 @@ public Builder() {} + /** + * Set the input and output array data format. Defaults to NCHW format - i.e., channels first. + * See {@link CNN2DFormat} for more details + * @param format Format to use + */ + public Builder dataFormat(CNN2DFormat format){ + this.cnn2DFormat = format; + return this; + } + /** * If doing minibatch training or not. Default: true. Under most circumstances, this should be set to true. If * doing full batch training (i.e., all examples in a single DataSet object - very small data sets) then this diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java index 3bcae0357..647b187e3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java @@ -22,6 +22,7 @@ import lombok.NoArgsConstructor; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -62,10 +63,12 @@ import java.util.Map; public class CnnLossLayer extends FeedForwardLayer { protected ILossFunction lossFn; + protected CNN2DFormat format = CNN2DFormat.NCHW; private CnnLossLayer(Builder builder) { super(builder); this.lossFn = builder.lossFn; + this.format = builder.format; } @Override @@ -114,12 +117,16 @@ public class CnnLossLayer extends FeedForwardLayer { @Override public void setNIn(InputType inputType, boolean override) { - //No op + if(inputType instanceof InputType.InputTypeConvolutional){ + this.format = ((InputType.InputTypeConvolutional) inputType).getFormat(); + } } public static class Builder extends BaseOutputLayer.Builder { + protected CNN2DFormat format = CNN2DFormat.NCHW; + public Builder() { this.activationFn = Activation.IDENTITY.getActivationFunction(); } @@ -132,6 +139,11 @@ public class CnnLossLayer extends FeedForwardLayer { this.lossFn = lossFunction; } + public Builder format(CNN2DFormat format){ + this.format = format; + return this; + } + @Override @SuppressWarnings("unchecked") public Builder nIn(int nIn) { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index b220ba5a6..f4d247670 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -22,6 +22,7 @@ import lombok.NoArgsConstructor; import lombok.ToString; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.Convolution1DUtils; @@ -43,6 +44,7 @@ import java.util.Map; @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) public class Convolution1DLayer extends ConvolutionLayer { + private RNNFormat rnnDataFormat = RNNFormat.NCW; /* //TODO: We will eventually want to NOT subclass off of ConvolutionLayer.
//Currently, we just subclass off the ConvolutionLayer and hard code the "width" dimension to 1 @@ -55,6 +57,7 @@ public class Convolution1DLayer extends ConvolutionLayer { private Convolution1DLayer(Builder builder) { super(builder); initializeConstraints(builder); + this.rnnDataFormat = builder.rnnDataFormat; } @Override @@ -91,7 +94,8 @@ public class Convolution1DLayer extends ConvolutionLayer { outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0], convolutionMode, dilation[0]); } - return InputType.recurrent(nOut, outLength); + + return InputType.recurrent(nOut, outLength, rnnDataFormat); } @Override @@ -101,10 +105,11 @@ public class Convolution1DLayer extends ConvolutionLayer { + "\"): expect RNN input type with size > 0. Got: " + inputType); } + InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; if (nIn <= 0 || override) { - InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; this.nIn = r.getSize(); } + this.rnnDataFormat = r.getFormat(); } @Override @@ -114,11 +119,13 @@ public class Convolution1DLayer extends ConvolutionLayer { + "\"): input is null"); } - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat,getLayerName()); } public static class Builder extends ConvolutionLayer.BaseConvBuilder { + private RNNFormat rnnDataFormat = RNNFormat.NCW; + public Builder() { this(0, 1, 0); this.setKernelSize((int[]) null); @@ -129,6 +136,11 @@ public class Convolution1DLayer extends ConvolutionLayer { return true; } + + public Builder rnnDataFormat(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + return this; + } /** * @param kernelSize Kernel size * @param stride Stride diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 9e52981e2..ebe1b8568 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -19,10 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.CacheMode; -import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -58,6 +55,7 @@ public class ConvolutionLayer extends FeedForwardLayer { protected int[] stride; // Default is 2. 
Down-sample by a factor of 2 protected int[] padding; protected boolean cudnnAllowFallback = true; + protected CNN2DFormat cnn2dDataFormat = CNN2DFormat.NCHW; /** * The "PREFER_FASTEST" mode will pick the fastest algorithm for the specified parameters from the {@link FwdAlgo}, @@ -139,6 +137,9 @@ public class ConvolutionLayer extends FeedForwardLayer { this.cudnnBwdFilterAlgo = builder.cudnnBwdFilterAlgo; this.cudnnBwdDataAlgo = builder.cudnnBwdDataAlgo; this.cudnnAllowFallback = builder.cudnnAllowFallback; + if(builder instanceof Builder) { + this.cnn2dDataFormat = ((Builder)builder).dataFormat; + } initializeConstraints(builder); } @@ -191,7 +192,7 @@ public class ConvolutionLayer extends FeedForwardLayer { } return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, - nOut, layerIndex, getLayerName(), ConvolutionLayer.class); + nOut, layerIndex, getLayerName(), cnn2dDataFormat, ConvolutionLayer.class); } @Override @@ -205,6 +206,7 @@ public class ConvolutionLayer extends FeedForwardLayer { InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; this.nIn = c.getChannels(); } + this.cnn2dDataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat(); } @Override @@ -285,6 +287,8 @@ public class ConvolutionLayer extends FeedForwardLayer { super(); } + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; + @Override protected boolean allowCausal() { //Causal convolution - allowed for 1D only @@ -311,6 +315,17 @@ public class ConvolutionLayer extends FeedForwardLayer { return this; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.dataFormat = format; + return this; + } + @Override @SuppressWarnings("unchecked") public ConvolutionLayer build() { @@ -359,6 +374,10 @@ public void setDilation(int... dilation) { this.dilation = ValidationUtils.validate2NonNegative(dilation, false, "dilation"); } + + public void setDataFormat(CNN2DFormat dataFormat){ + this.dataFormat = dataFormat; + } } @Getter
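Taken together with the format-aware InputType factories added earlier in this patch, the new ConvolutionLayer builder option allows an end-to-end channels-last configuration. A minimal sketch (layer sizes and the 28x28x1 input shape are illustrative only):

    import org.deeplearning4j.nn.conf.CNN2DFormat;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(new ConvolutionLayer.Builder(3, 3)
                    .nOut(16)
                    .dataFormat(CNN2DFormat.NHWC) // activations in/out as [mb, h, w, c]
                    .build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nOut(10).activation(Activation.SOFTMAX).build())
            // NHWC input [mb, 28, 28, 1]; setNIn(...) propagates the format to each layer
            .setInputType(InputType.convolutional(28, 28, 1, CNN2DFormat.NHWC))
            .build();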
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index 11c9fdb7b..8daa947df 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -22,6 +22,7 @@ import lombok.NoArgsConstructor; import lombok.ToString; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -133,6 +134,13 @@ public class Deconvolution2D extends ConvolutionLayer { super(); } + private CNN2DFormat format = CNN2DFormat.NCHW; + + public Builder format(CNN2DFormat format){ + this.format = format; + return this; + } + @Override protected boolean allowCausal() { //Causal convolution - allowed for 1D only diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java index e103cb0a0..e5e7b5436 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.layers.convolution.DepthwiseConvolution2DLayer; @@ -47,13 +48,14 @@ import java.util.*; @EqualsAndHashCode(callSuper = true) public class DepthwiseConvolution2D extends ConvolutionLayer { - int depthMultiplier; + protected int depthMultiplier; protected DepthwiseConvolution2D(Builder builder) { super(builder); Preconditions.checkState(builder.depthMultiplier > 0, "Depth multiplier must be > 0, got %s", builder.depthMultiplier); this.depthMultiplier = builder.depthMultiplier; this.nOut = this.nIn * this.depthMultiplier; + this.cnn2dDataFormat = builder.cnn2DFormat; initializeConstraints(builder); } @@ -95,7 +97,7 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { } return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, - nOut, layerIndex, getLayerName(), DepthwiseConvolution2DLayer.class); + nOut, layerIndex, getLayerName(), cnn2dDataFormat, DepthwiseConvolution2DLayer.class); } @Override @@ -105,6 +107,7 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { if(nOut == 0 || override){ nOut = this.nIn * this.depthMultiplier; } + this.cnn2dDataFormat = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Getter @@ -115,7 +118,9 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { * Set channels multiplier for depth-wise convolution * */ - public int depthMultiplier = 1; + protected int depthMultiplier = 1; + protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; + public Builder(int[] kernelSize, int[] stride, int[] padding) { super(kernelSize, stride, padding); @@ -139,6 +144,17 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { return false; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details. + * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.cnn2DFormat = format; + return this; + } + /** * Set channels multiplier for depth-wise convolution *
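A short sketch of the depth-wise case, using the array-based builder constructor visible above (kernel, stride, and padding values are illustrative). Note that nOut is derived as nIn * depthMultiplier when not set explicitly, and setNIn(...) now also picks up the format from the incoming InputType:

    import org.deeplearning4j.nn.conf.CNN2DFormat;
    import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D;

    // 3x3 depth-wise convolution, channels-last, two output channels per input channel
    DepthwiseConvolution2D dw = new DepthwiseConvolution2D.Builder(
                    new int[]{3, 3}, new int[]{1, 1}, new int[]{0, 0})
            .nIn(8)
            .depthMultiplier(2) // nOut becomes 8 * 2 = 16
            .dataFormat(CNN2DFormat.NHWC)
            .build();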
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java index 9b7725801..1b76b6c7b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -58,12 +59,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { private int inputLength = 1; // By default only use one index to embed private boolean hasBias = false; private boolean inferInputLength = false; // use input length as provided by input data + private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models private EmbeddingSequenceLayer(Builder builder) { super(builder); this.hasBias = builder.hasBias; this.inputLength = builder.inputLength; this.inferInputLength = builder.inferInputLength; + this.outputFormat = builder.outputFormat; initializeConstraints(builder); } @@ -87,7 +90,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { throw new IllegalStateException("Invalid input for Embedding layer (layer index = " + layerIndex + ", layer name = \"" + getLayerName() + "\"): expect FF/RNN input type. Got: " + inputType); } - return InputType.recurrent(nOut, inputLength); + return InputType.recurrent(nOut, inputLength, outputFormat); } @Override @@ -167,6 +170,13 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { */ private boolean inferInputLength = true; + private RNNFormat outputFormat = RNNFormat.NCW; //Default value for older deserialized models + + public Builder outputDataFormat(RNNFormat format){ + this.outputFormat = format; + return this; + } + /** * If true: include bias parameters in the layer. False (default): no bias.
* diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java index 206071e38..20d3c926a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import org.deeplearning4j.nn.conf.DataFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor; @@ -35,6 +36,7 @@ public abstract class FeedForwardLayer extends BaseLayer { protected long nIn; protected long nOut; + protected DataFormat timeDistributedFormat; public FeedForwardLayer(Builder builder) { super(builder); @@ -51,7 +53,7 @@ public abstract class FeedForwardLayer extends BaseLayer { + getLayerName() + "\"): expected FeedForward input type. Got: " + inputType); } - return InputType.feedForward(nOut); + return InputType.feedForward(nOut, timeDistributedFormat); } @Override @@ -71,6 +73,11 @@ public abstract class FeedForwardLayer extends BaseLayer { this.nIn = f.getFlattenedSize(); } } + + if(inputType instanceof InputType.InputTypeFeedForward){ + InputType.InputTypeFeedForward f = (InputType.InputTypeFeedForward) inputType; + this.timeDistributedFormat = f.getTimeDistributedFormat(); + } } @Override @@ -87,11 +94,11 @@ public abstract class FeedForwardLayer extends BaseLayer { return null; case RNN: //RNN -> FF - return new RnnToFeedForwardPreProcessor(); + return new RnnToFeedForwardPreProcessor(((InputType.InputTypeRecurrent)inputType).getFormat()); case CNN: //CNN -> FF InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; - return new CnnToFeedForwardPreProcessor(c.getHeight(), c.getWidth(), c.getChannels()); + return new CnnToFeedForwardPreProcessor(c.getHeight(), c.getWidth(), c.getChannels(), c.getFormat()); case CNN3D: //CNN3D -> FF InputType.InputTypeConvolutional3D c3d = (InputType.InputTypeConvolutional3D) inputType; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java index 4de2d481b..d9e10e6f5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -127,7 +128,7 @@ public class GlobalPoolingLayer extends NoParamLayer { if (collapseDimensions) { return InputType.feedForward(conv.getChannels()); } else { - return InputType.convolutional(1, 1, conv.getChannels()); + return InputType.convolutional(1, 1, conv.getChannels(), conv.getFormat()); } case CNN3D: InputType.InputTypeConvolutional3D conv3d = (InputType.InputTypeConvolutional3D) inputType; @@ -150,7 +151,14 @@ public class 
GlobalPoolingLayer extends NoParamLayer { @Override public void setNIn(InputType inputType, boolean override) { - //Not applicable + if(inputType.getType() == InputType.Type.CNN){ + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; + if(c.getFormat() == CNN2DFormat.NCHW){ + poolingDimensions = new int[]{2,3}; + } else { + poolingDimensions = new int[]{1,2}; + } + } } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java index eb78323b6..a60c3a6bc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java @@ -19,8 +19,10 @@ package org.deeplearning4j.nn.conf.layers; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.exception.DL4JInvalidConfigException; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor; @@ -70,13 +72,13 @@ public class InputTypeUtil { if (convolutionMode == ConvolutionMode.Same) { long hOut = stride[0] * hIn; long wOut = stride[1] * wIn; - return InputType.convolutional(hOut, wOut, outputDepth); + return InputType.convolutional(hOut, wOut, outputDepth, i.getFormat()); } long hOut = sH * (hIn - 1) + kH - 2 * padH; long wOut = sW * (wIn - 1) + kW - 2 * padW; - return InputType.convolutional(hOut, wOut, outputDepth); + return InputType.convolutional(hOut, wOut, outputDepth, i.getFormat()); } public static InputType getOutputTypeDeconv3dLayer(InputType inputType, int[] kernelSize, int[] stride, int[] padding, @@ -332,10 +334,20 @@ return InputType.recurrent(outputDepth, outH); } + /** + * @deprecated Use {@link #getOutputTypeCnnLayers(InputType, int[], int[], int[], int[], ConvolutionMode, long, long, String, CNN2DFormat, Class)} + */ + @Deprecated + public static InputType getOutputTypeCnnLayers(InputType inputType, int[] kernelSize, int[] stride, int[] padding, + int[] dilation, ConvolutionMode convolutionMode, long outputDepth, long layerIdx, String layerName, + Class layerClass) { + return getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, outputDepth, + layerIdx, layerName, CNN2DFormat.NCHW, layerClass); + } public static InputType getOutputTypeCnnLayers(InputType inputType, int[] kernelSize, int[] stride, int[] padding, int[] dilation, ConvolutionMode convolutionMode, long outputDepth, long layerIdx, String layerName, - Class layerClass) { + CNN2DFormat format, Class layerClass) { if (convolutionMode == null) { String name = layerName == null ? "(not named)" : layerName; @@ -424,12 +436,12 @@ public class InputTypeUtil { int outH = (int) Math.ceil(inHeight / ((double) stride[0])); int outW = (int) Math.ceil(inWidth / ((double) stride[1])); - return InputType.convolutional(outH, outW, outputDepth); + return InputType.convolutional(outH, outW, outputDepth, format); } long hOut = (inHeight - kH + 2 * padH) / sH + 1; long wOut = (inWidth - kW + 2 * padW) / sW + 1; - return InputType.convolutional(hOut, wOut, outputDepth); + return InputType.convolutional(hOut, wOut, outputDepth, format); } private static String getConfigErrorCommonLine(long layerIdx, String layerName, Class layerClass, @@ -517,25 +529,31 @@ } } - public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, String layerName) { + public static InputPreProcessor getPreprocessorForInputTypeRnnLayers(InputType inputType, RNNFormat rnnDataFormat, String layerName) { if (inputType == null) { throw new IllegalStateException( "Invalid input for RNN layer (layer name = \"" + layerName + "\"): input type is null"); } switch (inputType.getType()) { - case FF: case CNNFlat: //FF -> RNN or CNNFlat -> RNN //In either case, input data format is a row vector per example - return new FeedForwardToRnnPreProcessor(); + return new FeedForwardToRnnPreProcessor(rnnDataFormat); + case FF: + //If time distributed format is defined, use that. Otherwise use the layer-defined rnnDataFormat, which may be default + InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward)inputType; + if(ff.getTimeDistributedFormat() != null && ff.getTimeDistributedFormat() instanceof RNNFormat){ + return new FeedForwardToRnnPreProcessor((RNNFormat) ff.getTimeDistributedFormat()); + } + return new FeedForwardToRnnPreProcessor(rnnDataFormat); case RNN: //RNN -> RNN: No preprocessor necessary return null; case CNN: //CNN -> RNN InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; - return new CnnToRnnPreProcessor(c.getHeight(), c.getWidth(), c.getChannels()); + return new CnnToRnnPreProcessor(c.getHeight(), c.getWidth(), c.getChannels(), rnnDataFormat); default: throw new RuntimeException("Unknown input type: " + inputType); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java index 0c3c6e383..a0fa4d680 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LearnedSelfAttentionLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; @@ -86,7 +87,7 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName()); } @Override
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java index b16703569..f4f49b79a 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -26,6 +27,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.regularization.Regularization; @@ -50,6 +52,7 @@ public class LocalResponseNormalization extends Layer { protected double beta = 0.75; // decay rate protected double alpha = 1e-4; // decay rate protected boolean cudnnAllowFallback = true; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; private LocalResponseNormalization(Builder builder) { super(builder); @@ -58,6 +61,7 @@ public class LocalResponseNormalization extends Layer { this.alpha = builder.alpha; this.beta = builder.beta; this.cudnnAllowFallback = builder.cudnnAllowFallback; + this.dataFormat = builder.dataFormat; } @Override @@ -99,7 +103,8 @@ public class LocalResponseNormalization extends Layer { @Override public void setNIn(InputType inputType, boolean override) { - //No op + Preconditions.checkState(inputType.getType() == InputType.Type.CNN, "Only CNN input types can be used with LocalResponseNormalization, got %s", inputType); + this.dataFormat = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -184,8 +189,10 @@ public class LocalResponseNormalization extends Layer { */ protected boolean cudnnAllowFallback = true; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; + public Builder(double k, double n, double alpha, double beta) { - this(k, n, alpha, beta, true); + this(k, n, alpha, beta, true, CNN2DFormat.NCHW); } public Builder(double k, double alpha, double beta) { @@ -263,6 +270,17 @@ public class LocalResponseNormalization extends Layer { return this; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat dataFormat){ + this.dataFormat = dataFormat; + return this; + } + @Override public LocalResponseNormalization build() { return new LocalResponseNormalization(this); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index 8fedee7b0..d43423bf4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -20,6 +20,7 @@ import lombok.*; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; @@ -136,7 +137,7 @@ public class LocallyConnected1D extends SameDiffLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index 6fad9ec69..9b8fb10aa 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -70,6 +71,7 @@ public class LocallyConnected2D extends SameDiffLayer { private int[] inputSize; private int[] outputSize; private int featureDim; + protected CNN2DFormat format = CNN2DFormat.NCHW; protected LocallyConnected2D(Builder builder) { super(builder); @@ -84,6 +86,7 @@ public class LocallyConnected2D extends SameDiffLayer { this.hasBias = builder.hasBias; this.inputSize = builder.inputSize; this.featureDim = kernel[0] * kernel[1] * (int) nIn; + this.format = builder.format; } private LocallyConnected2D() { @@ -97,17 +100,19 @@ public class LocallyConnected2D extends SameDiffLayer { throw new IllegalArgumentException("Input size has to be specified for locally connected layers."); } - int[] inputShape = new int[] {1, nIn, inputSize[0], inputSize[1]}; + boolean nchw = format == CNN2DFormat.NCHW; + + int[] inputShape = nchw ? 
new int[] {1, nIn, inputSize[0], inputSize[1]} : new int[] {1, inputSize[0], inputSize[1], nIn}; INDArray dummyInputForShapeInference = Nd4j.ones(inputShape); if (cm == ConvolutionMode.Same) { this.outputSize = ConvolutionUtils.getOutputSize(dummyInputForShapeInference, kernel, stride, null, cm, - dilation); + dilation, format); this.padding = ConvolutionUtils.getSameModeTopLeftPadding(outputSize, inputSize, kernel, stride, dilation); this.paddingBr = ConvolutionUtils.getSameModeBottomRightPadding(outputSize, inputSize, kernel, stride, dilation); } else { this.outputSize = ConvolutionUtils.getOutputSize(dummyInputForShapeInference, kernel, stride, padding, cm, - dilation); + dilation, format); } } @@ -123,7 +128,7 @@ public class LocallyConnected2D extends SameDiffLayer { computeOutputSize(); return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernel, stride, padding, new int[] {1, 1}, cm, nOut, - layerIndex, getLayerName(), LocallyConnected2D.class); + layerIndex, getLayerName(), format, LocallyConnected2D.class); } @Override @@ -133,6 +138,7 @@ public class LocallyConnected2D extends SameDiffLayer { this.nIn = c.getChannels(); this.featureDim = kernel[0] * kernel[1] * (int) nIn; } + this.format = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -181,6 +187,10 @@ public class LocallyConnected2D extends SameDiffLayer { int kH = kernel[0]; int kW = kernel[1]; + boolean nchw = format == CNN2DFormat.NCHW; + if(!nchw) + layerInput = layerInput.permute(0,3,1,2); //NHWC to NCHW + if(padding[0] > 0 || padding[1] > 0 || (cm == ConvolutionMode.Same && (paddingBr[0] > 0 || paddingBr[1] > 0))){ //Note: for same mode, bottom/right padding can be 1 more than top/left padding //NCHW format @@ -210,16 +220,15 @@ public class LocallyConnected2D extends SameDiffLayer { SDVariable reshapeResult = sameDiff.reshape(mmulResult, outH, outW, miniBatch, nOut); - SDVariable permutedResult = sameDiff.permute(reshapeResult, 2, 3, 0, 1); // (mb, nOut, outH, outW) + SDVariable permutedResult = nchw ? reshapeResult.permute(2, 3, 0, 1) : reshapeResult.permute(2, 0, 1, 3); // (mb, nOut, outH, outW) or (mb, outH, outW, nOut) if (hasBias) { SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY); - SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, true); + SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, nchw); return activation.asSameDiff("out", sameDiff, biasAddedResult); } else { return activation.asSameDiff("out", sameDiff, permutedResult); } - } @Override @@ -292,6 +301,7 @@ public class LocallyConnected2D extends SameDiffLayer { */ private boolean hasBias = true; + protected CNN2DFormat format = CNN2DFormat.NCHW; /** @@ -386,6 +396,17 @@ public class LocallyConnected2D extends SameDiffLayer { return this; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.format = format; + return this; + } + /** * @param hasBias If true (default is false) the layer will have a bias */ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java index d12e0ec74..bc746f8ab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; @@ -92,7 +93,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java index df0b16e6c..f1dcd73a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -53,12 +54,13 @@ import java.util.Map; @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) public class RnnLossLayer extends FeedForwardLayer { - + private RNNFormat rnnDataFormat = RNNFormat.NCW; protected ILossFunction lossFn; private RnnLossLayer(Builder builder) { super(builder); this.setLossFn(builder.lossFn); + this.rnnDataFormat = builder.rnnDataFormat; } @Override @@ -91,7 +93,7 @@ public class RnnLossLayer extends FeedForwardLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override @@ -111,8 +113,9 @@ public class RnnLossLayer extends FeedForwardLayer { public static class Builder extends BaseOutputLayer.Builder { - public Builder() { + private RNNFormat rnnDataFormat = RNNFormat.NCW; + public Builder() { } /** @@ -153,6 +156,14 @@ public class RnnLossLayer extends FeedForwardLayer { "This layer has no parameters, thus nIn will always equal nOut."); } + 
/** + * @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength], + * NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW. + */ + public Builder dataFormat(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + return this; + } @Override @SuppressWarnings("unchecked") public RnnLossLayer build() { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java index 078673f5d..cfd337514 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -51,9 +52,11 @@ import java.util.Map; @EqualsAndHashCode(callSuper = true) public class RnnOutputLayer extends BaseOutputLayer { + private RNNFormat rnnDataFormat = RNNFormat.NCW; private RnnOutputLayer(Builder builder) { super(builder); initializeConstraints(builder); + this.rnnDataFormat = builder.rnnDataFormat; } @Override @@ -85,7 +88,7 @@ public class RnnOutputLayer extends BaseOutputLayer { } InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType; - return InputType.recurrent(nOut, itr.getTimeSeriesLength()); + return InputType.recurrent(nOut, itr.getTimeSeriesLength(), itr.getFormat()); } @Override @@ -95,20 +98,22 @@ public class RnnOutputLayer extends BaseOutputLayer { + "\"): Expected RNN input, got " + inputType); } + InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; + this.rnnDataFormat = r.getFormat(); if (nIn <= 0 || override) { - InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; this.nIn = r.getSize(); } } @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, rnnDataFormat, getLayerName()); } public static class Builder extends BaseOutputLayer.Builder { + private RNNFormat rnnDataFormat = RNNFormat.NCW; public Builder() { //Set default activation function to softmax (to match default loss function MCXENT) this.setActivationFn(new ActivationSoftmax()); @@ -137,5 +142,14 @@ public class RnnOutputLayer extends BaseOutputLayer { public RnnOutputLayer build() { return new RnnOutputLayer(this); } + + /** + * @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength], + * NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW. 
+ */ + public Builder dataFormat(RNNFormat rnnDataFormat){ + this.rnnDataFormat = rnnDataFormat; + return this; + } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java index db898dbdd..8daa4a2c2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SelfAttentionLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; @@ -75,7 +76,7 @@ public class SelfAttentionLayer extends SameDiffLayer { @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW,getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java index 133c14869..f9ae11b49 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java @@ -20,6 +20,7 @@ import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer; @@ -85,6 +86,8 @@ public class SeparableConvolution2D extends ConvolutionLayer { this.cudnnFwdAlgo = builder.cudnnFwdAlgo; this.cudnnBwdFilterAlgo = builder.cudnnBwdFilterAlgo; this.cudnnBwdDataAlgo = builder.cudnnBwdDataAlgo; + this.cnn2dDataFormat = builder.dataFormat; + initializeConstraints(builder); } @@ -153,8 +156,10 @@ public class SeparableConvolution2D extends ConvolutionLayer { + "\"): Expected CNN input, got " + inputType); } + CNN2DFormat format = ((InputType.InputTypeConvolutional)inputType).getFormat(); + return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, - nOut, layerIndex, getLayerName(), SeparableConvolution2DLayer.class); + nOut, layerIndex, getLayerName(), format, SeparableConvolution2DLayer.class); } @@ -166,7 +171,8 @@ public class SeparableConvolution2D extends ConvolutionLayer { * Set channels multiplier of channels-wise step in separable convolution * */ - public int depthMultiplier = 1; + protected int depthMultiplier = 1; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; public Builder(int[] kernelSize, int[] stride, int[] padding) { super(kernelSize, stride, padding); @@ -190,6 +196,17 @@ public class SeparableConvolution2D extends ConvolutionLayer { return false; } + /** + * Set the data format for the 
CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.dataFormat = format; + return this; + } + /** * Set channels multiplier of channels-wise step in separable convolution * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java index 5d946f2c7..cd7db60ab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -26,6 +27,7 @@ import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -65,12 +67,14 @@ public class SpaceToBatchLayer extends NoParamLayer { protected int[] blocks; protected int[][] padding; + protected CNN2DFormat format = CNN2DFormat.NCHW; protected SpaceToBatchLayer(Builder builder) { super(builder); this.blocks = builder.blocks; this.padding = builder.padding; + this.format = builder.format; } @Override @@ -112,7 +116,7 @@ public class SpaceToBatchLayer extends NoParamLayer { } InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType; return InputType.convolutional((i.getHeight() + padding[0][0] + padding[0][1]) / blocks[0], - (i.getWidth() + padding[1][0] + padding[1][1]) / blocks[1], i.getChannels()); + (i.getWidth() + padding[1][0] + padding[1][1]) / blocks[1], i.getChannels(), i.getFormat()); } @Override @@ -123,7 +127,8 @@ public class SpaceToBatchLayer extends NoParamLayer { @Override public void setNIn(InputType inputType, boolean override) { - //No op: space to batch layer doesn't have nIn value + Preconditions.checkState(inputType.getType() == InputType.Type.CNN, "Only CNN input types can be used with SpaceToBatchLayer, got %s", inputType); + this.format = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -158,6 +163,8 @@ public class SpaceToBatchLayer extends NoParamLayer { */ protected int[][] padding; + protected CNN2DFormat format = CNN2DFormat.NCHW; + /** * @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width * dimensions @@ -193,6 +200,17 @@ public class SpaceToBatchLayer extends NoParamLayer { this.setPadding(padding); } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public T dataFormat(CNN2DFormat format){ + this.format = format; + return (T)this; + } + /** * @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width dimensions diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java index 44f8bb666..53d9007be 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -56,12 +57,20 @@ import java.util.Map; @EqualsAndHashCode(callSuper = true) public class SpaceToDepthLayer extends NoParamLayer { + /** + * @deprecated Use {@link CNN2DFormat} instead + */ + @Deprecated public enum DataFormat { - NCHW, NHWC + NCHW, NHWC; + + public CNN2DFormat toFormat(){ + return this == NCHW ? CNN2DFormat.NCHW : CNN2DFormat.NHWC; + } } protected int blockSize; - protected DataFormat dataFormat; + protected CNN2DFormat dataFormat; protected SpaceToDepthLayer(Builder builder) { @@ -108,7 +117,7 @@ public class SpaceToDepthLayer extends NoParamLayer { } InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType; return InputType.convolutional(i.getHeight() / blockSize, i.getWidth() / blockSize, - i.getChannels() * blockSize * blockSize); + i.getChannels() * blockSize * blockSize, i.getFormat()); } @Override @@ -119,7 +128,7 @@ public class SpaceToDepthLayer extends NoParamLayer { @Override public void setNIn(InputType inputType, boolean override) { - //No op: space to batch layer doesn't have nIn value + this.dataFormat = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -147,7 +156,7 @@ public class SpaceToDepthLayer extends NoParamLayer { /** * Data format for input activations. Note DL4J uses NCHW in most cases */ - protected DataFormat dataFormat = DataFormat.NCHW; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; /** * @param blockSize Block size @@ -160,7 +169,12 @@ public class SpaceToDepthLayer extends NoParamLayer { * @param blockSize Block size * @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases */ + @Deprecated public Builder(int blockSize, DataFormat dataFormat) { + this(blockSize, dataFormat.toFormat()); + } + + public Builder(int blockSize, CNN2DFormat dataFormat) { this.setBlockSize(blockSize); this.setDataFormat(dataFormat); } @@ -175,8 +189,20 @@ public class SpaceToDepthLayer extends NoParamLayer { /** * @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases + * @deprecated Use {@link #dataFormat(CNN2DFormat)} */ + @Deprecated public T dataFormat(DataFormat dataFormat) { + return dataFormat(dataFormat.toFormat()); + } + + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details. + * Default: NCHW + * @param dataFormat Format for activations (in and out) + */ + public T dataFormat(CNN2DFormat dataFormat) { this.setDataFormat(dataFormat); return (T) this; }
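Since the layer-local enum is now deprecated in favour of the shared CNN2DFormat, existing configurations would migrate roughly as follows (a sketch; the block size of 2 is illustrative):

    import org.deeplearning4j.nn.conf.CNN2DFormat;
    import org.deeplearning4j.nn.conf.layers.SpaceToDepthLayer;

    // Before (deprecated, still works - routed through DataFormat.toFormat()):
    SpaceToDepthLayer before = new SpaceToDepthLayer.Builder(2, SpaceToDepthLayer.DataFormat.NHWC).build();

    // After:
    SpaceToDepthLayer after = new SpaceToDepthLayer.Builder(2, CNN2DFormat.NHWC).build();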
+ * Default: NCHW + * @param dataFormat Format for activations (in and out) + */ + public T dataFormat(CNN2DFormat dataFormat) { this.setDataFormat(dataFormat); return (T) this; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index 9f3162374..5d2a55994 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -22,6 +22,7 @@ import lombok.NoArgsConstructor; import lombok.ToString; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.Convolution1DUtils; @@ -90,7 +91,7 @@ public class Subsampling1DLayer extends SubsamplingLayer { outLength = Convolution1DUtils.getOutputSize(inputTsLength, kernelSize[0], stride[0], padding[0], convolutionMode, dilation[0]); } - return InputType.recurrent(r.getSize(), outLength); + return InputType.recurrent(r.getSize(), outLength, r.getFormat()); } @Override @@ -105,7 +106,7 @@ public class Subsampling1DLayer extends SubsamplingLayer { + "\"): input is null"); } - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index be6764e9a..8b09aedf1 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -58,6 +59,7 @@ public class SubsamplingLayer extends NoParamLayer { protected int pnorm; protected double eps; protected boolean cudnnAllowFallback = true; + protected CNN2DFormat cnn2dDataFormat = CNN2DFormat.NCHW; /* Default here for JSON deserialization of 1.0.0-beta4 and earlier models. New models default to false via builder. This impacts average pooling only - whether the divisor should include or exclude padding along image edges. 
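+ //Usage note (hypothetical sketch, not from the original sources): instead of setting the format on every
+ //layer builder, it can be inferred network-wide from the input type, e.g.
+ //  .setInputType(InputType.convolutional(height, width, channels, CNN2DFormat.NHWC))
+ //which reaches this field via setNIn(...) in the hunk below.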
@@ -121,6 +123,7 @@ public class SubsamplingLayer extends NoParamLayer { if (clone.dilation != null) { clone.dilation = clone.dilation.clone(); } + return clone; } @@ -153,12 +156,13 @@ public class SubsamplingLayer extends NoParamLayer { return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, ((InputType.InputTypeConvolutional) inputType).getChannels(), layerIndex, getLayerName(), - SubsamplingLayer.class); + cnn2dDataFormat, SubsamplingLayer.class); } @Override public void setNIn(InputType inputType, boolean override) { //No op: subsampling layer doesn't have nIn value + this.cnn2dDataFormat = ((InputType.InputTypeConvolutional)inputType).getFormat(); } @Override @@ -229,6 +233,7 @@ public class SubsamplingLayer extends NoParamLayer { * Dilation for kernel */ private int[] dilation = new int[] {1, 1}; + protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; public Builder(PoolingType poolingType, int[] kernelSize, int[] stride) { super(poolingType, kernelSize, stride); @@ -307,6 +312,17 @@ public class SubsamplingLayer extends NoParamLayer { return this; } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
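+ * A minimal usage sketch (hypothetical, using the {@code (poolingType, kernelSize, stride)} constructor above):
+ * <pre>{@code
+ * SubsamplingLayer pool = new SubsamplingLayer.Builder(PoolingType.MAX, new int[]{2, 2}, new int[]{2, 2})
+ *         .dataFormat(CNN2DFormat.NHWC)
+ *         .build();
+ * }</pre>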
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.dataFormat = format; + return this; + } + /** * Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated convolutions, * which are also known as atrous convolutions.
NOTE: Kernel dilation is less common in practice for @@ -358,6 +374,10 @@ public class SubsamplingLayer extends NoParamLayer { public void setDilation(int[] dilation) { this.dilation = ValidationUtils.validate2NonNegative(dilation, false, "dilation"); } + + public void setDataFormat(CNN2DFormat format){ + this.dataFormat = format; + } } @NoArgsConstructor diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java index 0f1a770a8..0357c3e7b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -59,10 +60,12 @@ public class Upsampling2D extends BaseUpsamplingLayer { @JsonDeserialize(using = LegacyIntArrayDeserializer.class) protected int[] size; + protected CNN2DFormat format = CNN2DFormat.NCHW; protected Upsampling2D(UpsamplingBuilder builder) { super(builder); this.size = builder.size; + this.format = ((Builder)builder).format; } @Override @@ -97,7 +100,7 @@ public class Upsampling2D extends BaseUpsamplingLayer { val inWidth = i.getWidth(); val inDepth = i.getChannels(); - return InputType.convolutional(size[0] * inHeight, size[1] * inWidth, inDepth); + return InputType.convolutional(size[0] * inHeight, size[1] * inWidth, inDepth, i.getFormat()); } @Override @@ -131,14 +134,35 @@ public class Upsampling2D extends BaseUpsamplingLayer { .build(); } + @Override + public void setNIn(InputType inputType, boolean override) { + if (inputType == null || inputType.getType() != InputType.Type.CNN) { + throw new IllegalStateException("Invalid input for Upsampling 2D layer (layer name=\"" + getLayerName() + + "\"): Expected CNN input, got " + inputType); + } + this.format = ((InputType.InputTypeConvolutional)inputType).getFormat(); + } @NoArgsConstructor public static class Builder extends UpsamplingBuilder { + protected CNN2DFormat format = CNN2DFormat.NCHW; + public Builder(int size) { super(new int[] {size, size}); } + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
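+ * A one-line usage sketch (hypothetical, using the single-int size constructor above):
+ * <pre>{@code
+ * Upsampling2D up = new Upsampling2D.Builder(2).dataFormat(CNN2DFormat.NHWC).build();
+ * }</pre>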
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.format = format; + return this; + } + /** * Upsampling size int, used for both height and width * @@ -146,7 +170,7 @@ public class Upsampling2D extends BaseUpsamplingLayer { */ public Builder size(int size) { - this.setSize(new int[] {size, size}); + this.setSize(size, size); return this; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java index e888a2904..a3345fde9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java @@ -20,6 +20,7 @@ import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -104,7 +105,7 @@ public class ZeroPadding1DLayer extends NoParamLayer { + "\"): input is null"); } - return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, getLayerName()); + return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getLayerName()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java index 48463b76b..30e46edab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -45,6 +46,7 @@ import java.util.Map; public class ZeroPaddingLayer extends NoParamLayer { private int[] padding; + private CNN2DFormat dataFormat = CNN2DFormat.NCHW; public ZeroPaddingLayer(int padTopBottom, int padLeftRight) { this(new Builder(padTopBottom, padLeftRight)); @@ -63,6 +65,7 @@ public class ZeroPaddingLayer extends NoParamLayer { } this.padding = builder.padding; + this.dataFormat = builder.cnn2DFormat; } @Override @@ -85,7 +88,9 @@ public class ZeroPaddingLayer extends NoParamLayer { int outH = hwd[0] + padding[0] + padding[1]; int outW = hwd[1] + padding[2] + padding[3]; - return InputType.convolutional(outH, outW, hwd[2]); + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType; + + return InputType.convolutional(outH, outW, hwd[2], c.getFormat()); } @Override @@ -107,6 +112,12 @@ public class ZeroPaddingLayer extends NoParamLayer { .build(); } + @Override + public void setNIn(InputType inputType, boolean override) { + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType; + this.dataFormat = c.getFormat(); + } + @Getter @Setter public static class Builder extends Layer.Builder { 
@@ -117,6 +128,19 @@ public class ZeroPaddingLayer extends NoParamLayer { @Setter(AccessLevel.NONE) private int[] padding = new int[] {0, 0, 0, 0}; //Padding: top, bottom, left, right + private CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; + + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
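+ * A minimal usage sketch (hypothetical; assumes the two-arg {@code (padTopBottom, padLeftRight)} builder constructor):
+ * <pre>{@code
+ * ZeroPaddingLayer pad = new ZeroPaddingLayer.Builder(1, 1).dataFormat(CNN2DFormat.NHWC).build();
+ * }</pre>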
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.cnn2DFormat = format; + return this; + } + /** * @param padding Padding value for top, bottom, left, and right. Must be length 4 array */ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java index 497bb9a06..7b8852dc0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java @@ -17,12 +17,14 @@ package org.deeplearning4j.nn.conf.layers.convolutional; import lombok.*; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.NoParamLayer; +import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.convolution.Cropping2DLayer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -47,6 +49,7 @@ import java.util.Map; public class Cropping2D extends NoParamLayer { private int[] cropping; + private CNN2DFormat dataFormat = CNN2DFormat.NCHW; /** * @param cropTopBottom Amount of cropping to apply to both the top and the bottom of the input activations @@ -56,6 +59,10 @@ public class Cropping2D extends NoParamLayer { this(cropTopBottom, cropTopBottom, cropLeftRight, cropLeftRight); } + public Cropping2D(CNN2DFormat dataFormat, int cropTopBottom, int cropLeftRight) { + this(dataFormat, cropTopBottom, cropTopBottom, cropLeftRight, cropLeftRight); + } + /** * @param cropTop Amount of cropping to apply to the top of the input activations * @param cropBottom Amount of cropping to apply to the bottom of the input activations @@ -63,7 +70,11 @@ public class Cropping2D extends NoParamLayer { * @param cropRight Amount of cropping to apply to the right of the input activations */ public Cropping2D(int cropTop, int cropBottom, int cropLeft, int cropRight) { - this(new Builder(cropTop, cropBottom, cropLeft, cropRight)); + this(CNN2DFormat.NCHW, cropTop, cropBottom, cropLeft, cropRight); + } + + public Cropping2D(CNN2DFormat format, int cropTop, int cropBottom, int cropLeft, int cropRight) { + this(new Builder(cropTop, cropBottom, cropLeft, cropRight).dataFormat(format)); } /** @@ -77,6 +88,7 @@ public class Cropping2D extends NoParamLayer { protected Cropping2D(Builder builder) { super(builder); this.cropping = builder.cropping; + this.dataFormat = builder.cnn2DFormat; } @Override @@ -98,7 +110,9 @@ public class Cropping2D extends NoParamLayer { int outH = hwd[0] - cropping[0] - cropping[1]; int outW = hwd[1] - cropping[2] - cropping[3]; - return InputType.convolutional(outH, outW, hwd[2]); + InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType; + + return InputType.convolutional(outH, outW, hwd[2], c.getFormat()); } @Override @@ -113,6 +127,10 @@ public class Cropping2D extends NoParamLayer { return null; } + @Override + public void setNIn(InputType inputType, boolean override) { + this.dataFormat 
= ((InputType.InputTypeConvolutional)inputType).getFormat(); + } @Getter @Setter @@ -124,6 +142,19 @@ public class Cropping2D extends NoParamLayer { @Setter(AccessLevel.NONE) private int[] cropping = new int[] {0, 0, 0, 0}; + private CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW; + + /** + * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). + * See {@link CNN2DFormat} for more details.
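+ * A minimal usage sketch (hypothetical; assumes the two-arg {@code (cropTopBottom, cropLeftRight)} builder constructor):
+ * <pre>{@code
+ * Cropping2D crop = new Cropping2D.Builder(1, 1).dataFormat(CNN2DFormat.NHWC).build();
+ * }</pre>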
+ * Default: NCHW + * @param format Format for activations (in and out) + */ + public Builder dataFormat(CNN2DFormat format){ + this.cnn2DFormat = format; + return this; + } + /** * @param cropping Cropping amount for top/bottom/left/right (in that order). Must be length 1, 2, or 4 array. */ diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java index 424130f17..e7f252c4b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.layers.misc; import lombok.*; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; @@ -46,10 +47,12 @@ import java.util.Map; public class RepeatVector extends FeedForwardLayer { private int n = 1; + private RNNFormat dataFormat = RNNFormat.NCW; protected RepeatVector(Builder builder) { super(builder); this.n = builder.n; + this.dataFormat = builder.dataFormat; } @Override @@ -83,7 +86,7 @@ public class RepeatVector extends FeedForwardLayer { + "\"): Expected FF input, got " + inputType); } InputType.InputTypeFeedForward ffInput = (InputType.InputTypeFeedForward) inputType; - return InputType.recurrent(ffInput.getSize(), n); + return InputType.recurrent(ffInput.getSize(), n, this.dataFormat); } @Override @@ -101,13 +104,14 @@ public class RepeatVector extends FeedForwardLayer { } + @NoArgsConstructor @Getter @Setter public static class Builder> extends FeedForwardLayer.Builder { private int n = 1; // no repetition by default - + private RNNFormat dataFormat = RNNFormat.NCW; /** * Set repetition factor for RepeatVector layer */ @@ -115,6 +119,15 @@ public class RepeatVector extends FeedForwardLayer { return n; } + public RNNFormat getDataFormat(){ + return dataFormat; + } + + public Builder dataFormat(RNNFormat dataFormat){ + this.dataFormat = dataFormat; + return this; + } + /** * Set repetition factor for RepeatVector layer * diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java index 4d32f22e5..792e5633b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java @@ -21,6 +21,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; @@ -30,6 +31,7 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer; import 
org.deeplearning4j.nn.params.BidirectionalParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; @@ -124,6 +126,10 @@ public class Bidirectional extends Layer { } } + public RNNFormat getRNNDataFormat(){ + return TimeSeriesUtils.getFormatFromRnnLayer(fwd); + } + @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, @@ -170,7 +176,7 @@ public class Bidirectional extends Layer { } else { InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) outOrig; if (mode == Mode.CONCAT) { - return InputType.recurrent(2 * r.getSize()); + return InputType.recurrent(2 * r.getSize(), getRNNDataFormat()); } else { return r; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java index bd9685ef9..5489ccc78 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java @@ -5,6 +5,7 @@ import lombok.EqualsAndHashCode; import lombok.NonNull; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; @@ -29,17 +30,19 @@ import java.util.Collection; @EqualsAndHashCode(callSuper = true) public class TimeDistributed extends BaseWrapperLayer { - private final int timeAxis; + private RNNFormat rnnDataFormat = RNNFormat.NCW; /** * @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayer - * @param timeAxis Time axis, should be 2 for DL4J RNN activations (shape [minibatch, size, sequenceLength]) */ - public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("timeAxis") int timeAxis) { + public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) { super(underlying); - this.timeAxis = timeAxis; + this.rnnDataFormat = rnnDataFormat; } + public TimeDistributed(Layer underlying){ + super(underlying); + } @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, @@ -47,7 +50,7 @@ public class TimeDistributed extends BaseWrapperLayer { NeuralNetConfiguration conf2 = conf.clone(); conf2.setLayer(((TimeDistributed) conf2.getLayer()).getUnderlying()); return new TimeDistributedLayer(underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, - initializeParams, networkDataType), timeAxis); + initializeParams, networkDataType), rnnDataFormat); } @Override @@ -59,7 +62,7 @@ public class TimeDistributed extends BaseWrapperLayer { InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType; InputType ff = InputType.feedForward(rnn.getSize()); InputType ffOut = underlying.getOutputType(layerIndex, ff); - return 
InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength()); + return InputType.recurrent(ffOut.arrayElementsPerExample(), rnn.getTimeSeriesLength(), rnnDataFormat); } @Override @@ -70,6 +73,7 @@ public class TimeDistributed extends BaseWrapperLayer { InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType; InputType ff = InputType.feedForward(rnn.getSize()); + this.rnnDataFormat = rnn.getFormat(); underlying.setNIn(ff, override); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java index 81d37a067..a9f82ba32 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToFeedForwardPreProcessor.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.preprocessor; import lombok.Data; import lombok.val; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -38,11 +39,13 @@ import java.util.Arrays; * For example, CNN -> Denselayer
* This does two things:<br>
* (a) Reshapes 4d activations out of CNN layer, with shape
- * [numExamples, numChannels, inputHeight, inputWidth]) into 2d activations (with shape
- * [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer
+ * [numExamples, numChannels, inputHeight, inputWidth] (for {@link CNN2DFormat#NCHW} format activations) or shape
+ * [numExamples, inputHeight, inputWidth, numChannels] (for {@link CNN2DFormat#NHWC} format activations) into 2d activations
+ * (with shape [numExamples, inputHeight*inputWidth*numChannels]) for use in feed forward layer.
* (b) Reshapes epsilons (weights*deltas) out of FeedForward layer (which is 2D or 3D with shape
* [numExamples, inputHeight*inputWidth*numChannels]) into 4d epsilons (with shape
- * [numExamples, numChannels, inputHeight, inputWidth]) suitable to feed into CNN layers.
+ * [numExamples, numChannels, inputHeight, inputWidth] or [numExamples, inputHeight, inputWidth, numChannels]) suitable to
+ * feed into CNN layers.<br>
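+ * A short construction sketch (hypothetical values, using the four-arg constructor added in this file):
+ * <pre>{@code
+ * CnnToFeedForwardPreProcessor proc = new CnnToFeedForwardPreProcessor(28, 28, 3, CNN2DFormat.NHWC);
+ * }</pre>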
* Note: numChannels is equivalent to channels or featureMaps referenced in different literature * @author Adam Gibson * @see FeedForwardToCnnPreProcessor for opposite case (i.e., DenseLayer -> CNNetc) @@ -52,6 +55,7 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { protected long inputHeight; protected long inputWidth; protected long numChannels; + protected CNN2DFormat format = CNN2DFormat.NCHW; //Default for legacy JSON deserialization /** * @param inputHeight the columns @@ -61,16 +65,21 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { @JsonCreator public CnnToFeedForwardPreProcessor(@JsonProperty("inputHeight") long inputHeight, - @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) { + @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels, + @JsonProperty("format") CNN2DFormat format) { this.inputHeight = inputHeight; this.inputWidth = inputWidth; this.numChannels = numChannels; + if(format != null) + this.format = format; } public CnnToFeedForwardPreProcessor(long inputHeight, long inputWidth) { - this.inputHeight = inputHeight; - this.inputWidth = inputWidth; - this.numChannels = 1; + this(inputHeight, inputWidth, 1, CNN2DFormat.NCHW); + } + + public CnnToFeedForwardPreProcessor(long inputHeight, long inputWidth, long numChannels) { + this(inputHeight, inputWidth, numChannels, CNN2DFormat.NCHW); } public CnnToFeedForwardPreProcessor() {} @@ -80,15 +89,32 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { if (input.rank() == 2) return input; //Should usually never happen - if(input.size(1) != numChannels || input.size(2) != inputHeight || input.size(3) != inputWidth){ - throw new IllegalStateException("Invalid input, does not match configuration: expected [minibatch, numChannels=" - + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] but got input array of" + - "shape " + Arrays.toString(input.shape())); + + int chDim = 1; + int hDim = 2; + int wDim = 3; + if(format == CNN2DFormat.NHWC){ + chDim = 3; + hDim = 1; + wDim = 2; + } + + if(inputHeight == 0 && inputWidth == 0 && numChannels == 0){ + this.inputHeight = input.size(hDim); + this.inputWidth = input.size(wDim); + this.numChannels = input.size(chDim); + } + + if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || input.size(wDim) != inputWidth){ + throw new IllegalStateException("Invalid input, does not match configuration: expected " + + (format == CNN2DFormat.NCHW ? 
"[minibatch, numChannels=" + numChannels + ", inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + "] " : + "[minibatch, inputHeight=" + inputHeight + ", inputWidth=" + inputWidth + ", numChannels=" + numChannels + "]") + + " but got input array of shape " + Arrays.toString(input.shape())); } //Check input: nchw format - if(input.size(1) != numChannels || input.size(2) != inputHeight || - input.size(3) != inputWidth){ + if(input.size(chDim) != numChannels || input.size(hDim) != inputHeight || + input.size(wDim) != inputWidth){ throw new IllegalStateException("Invalid input array: expected shape [minibatch, channels, height, width] = " + "[minibatch, " + numChannels + ", " + inputHeight + ", " + inputWidth + "] - got " + Arrays.toString(input.shape())); @@ -99,6 +125,8 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); + //Note that to match Tensorflow/Keras, we do a simple "c order reshape" for both NCHW and NHWC + val inShape = input.shape(); //[miniBatch,depthOut,outH,outW] val outShape = new long[]{inShape[0], inShape[1] * inShape[2] * inShape[3]}; @@ -119,7 +147,13 @@ public class CnnToFeedForwardPreProcessor implements InputPreProcessor { + inputHeight + " x columns " + inputWidth + " x channels " + numChannels + " but was instead " + Arrays.toString(epsilons.shape())); - INDArray ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth); + INDArray ret; + if(format == CNN2DFormat.NCHW){ + ret = epsilons.reshape('c', epsilons.size(0), numChannels, inputHeight, inputWidth); + } else { + ret = epsilons.reshape('c', epsilons.size(0), inputHeight, inputWidth, numChannels); + } + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); //Move if required to specified workspace } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java index 42dca9105..43b1b1e7c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/CnnToRnnPreProcessor.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.conf.preprocessor; import lombok.*; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.base.Preconditions; @@ -38,7 +39,7 @@ import java.util.Arrays; * Functionally equivalent to combining CnnToFeedForwardPreProcessor + FeedForwardToRnnPreProcessor
* Specifically, this does two things:<br>
* (a) Reshapes 4d activations out of CNN layer, with shape [timeSeriesLength*miniBatchSize, numChannels, inputHeight, inputWidth]
- * into 3d (time series) activations (with shape [numExamples, inputHeight*inputWidth*numChannels, timeSeriesLength])
+ * into 3d (time series) activations (with shape [miniBatchSize, inputHeight*inputWidth*numChannels, timeSeriesLength])
+ * for use in RNN layers<br>
* (b) Reshapes 3d epsilons (weights.*deltas) out of RNN layer (with shape * [miniBatchSize,inputHeight*inputWidth*numChannels,timeSeriesLength]) into 4d epsilons with shape @@ -52,6 +53,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { private long inputHeight; private long inputWidth; private long numChannels; + private RNNFormat rnnDataFormat = RNNFormat.NCW; @Getter(AccessLevel.NONE) @Setter(AccessLevel.NONE) @@ -59,11 +61,20 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { @JsonCreator public CnnToRnnPreProcessor(@JsonProperty("inputHeight") long inputHeight, - @JsonProperty("inputWidth") long inputWidth, @JsonProperty("numChannels") long numChannels) { + @JsonProperty("inputWidth") long inputWidth, + @JsonProperty("numChannels") long numChannels, + @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) { this.inputHeight = inputHeight; this.inputWidth = inputWidth; this.numChannels = numChannels; this.product = inputHeight * inputWidth * numChannels; + this.rnnDataFormat = rnnDataFormat; + } + + public CnnToRnnPreProcessor(long inputHeight, + long inputWidth, + long numChannels){ + this(inputHeight, inputWidth, numChannels, RNNFormat.NCW); } @Override @@ -90,14 +101,19 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { //Second: reshape 2d to 3d, as per FeedForwardToRnnPreProcessor INDArray reshaped = workspaceMgr.dup(ArrayType.ACTIVATIONS, twod, 'f'); reshaped = reshaped.reshape('f', miniBatchSize, shape[0] / miniBatchSize, product); - return reshaped.permute(0, 2, 1); + if (rnnDataFormat == RNNFormat.NCW) { + return reshaped.permute(0, 2, 1); + } + return reshaped; } @Override public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { if (output.ordering() == 'c' || !Shape.hasDefaultStridesForShape(output)) output = output.dup('f'); - + if (rnnDataFormat == RNNFormat.NWC){ + output = output.permute(0, 2, 1); + } val shape = output.shape(); INDArray output2d; if (shape[0] == 1) { @@ -122,7 +138,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { @Override public CnnToRnnPreProcessor clone() { - return new CnnToRnnPreProcessor(inputHeight, inputWidth, numChannels); + return new CnnToRnnPreProcessor(inputHeight, inputWidth, numChannels, rnnDataFormat); } @Override @@ -133,7 +149,7 @@ public class CnnToRnnPreProcessor implements InputPreProcessor { InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; val outSize = c.getChannels() * c.getHeight() * c.getWidth(); - return InputType.recurrent(outSize); + return InputType.recurrent(outSize, rnnDataFormat); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java index aa45b30a5..d0b01698e 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToRnnPreProcessor.java @@ -21,6 +21,7 @@ import lombok.NoArgsConstructor; import lombok.val; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.util.TimeSeriesUtils; import 
org.nd4j.linalg.api.ndarray.INDArray; @@ -28,7 +29,7 @@ import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.ArrayType; - +import org.nd4j.shade.jackson.annotation.JsonProperty; import java.util.Arrays; /** @@ -48,7 +49,12 @@ import java.util.Arrays; @Data @NoArgsConstructor public class FeedForwardToRnnPreProcessor implements InputPreProcessor { + private RNNFormat rnnDataFormat = RNNFormat.NCW; + public FeedForwardToRnnPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){ + if(rnnDataFormat != null) + this.rnnDataFormat = rnnDataFormat; + } @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { //Need to reshape FF activations (2d) activations to 3d (for input into RNN layer) @@ -60,7 +66,10 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor { val shape = input.shape(); INDArray reshaped = input.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]); - return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reshaped.permute(0, 2, 1)); + if (rnnDataFormat == RNNFormat.NCW){ + reshaped = reshaped.permute(0, 2, 1); + } + return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reshaped); } @Override @@ -71,6 +80,9 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor { "Invalid input: expect NDArray with rank 3 (i.e., epsilons from RNN layer)"); if (output.ordering() != 'f' || !Shape.hasDefaultStridesForShape(output)) output = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, output, 'f'); + if (rnnDataFormat == RNNFormat.NWC){ + output = output.permute(0, 2, 1); + } val shape = output.shape(); INDArray ret; @@ -87,12 +99,7 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor { @Override public FeedForwardToRnnPreProcessor clone() { - try { - FeedForwardToRnnPreProcessor clone = (FeedForwardToRnnPreProcessor) super.clone(); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } + return new FeedForwardToRnnPreProcessor(rnnDataFormat); } @Override @@ -104,10 +111,10 @@ public class FeedForwardToRnnPreProcessor implements InputPreProcessor { if (inputType.getType() == InputType.Type.FF) { InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward) inputType; - return InputType.recurrent(ff.getSize()); + return InputType.recurrent(ff.getSize(), rnnDataFormat); } else { InputType.InputTypeConvolutionalFlat cf = (InputType.InputTypeConvolutionalFlat) inputType; - return InputType.recurrent(cf.getFlattenedSize()); + return InputType.recurrent(cf.getFlattenedSize(), rnnDataFormat); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java index a3920b061..bcfc92170 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToCnnPreProcessor.java @@ -19,8 +19,10 @@ package org.deeplearning4j.nn.conf.preprocessor; import lombok.*; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import 
org.deeplearning4j.util.TimeSeriesUtils; +import org.nd4j.enums.RnnDataFormat; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.primitives.Pair; @@ -52,19 +54,27 @@ public class RnnToCnnPreProcessor implements InputPreProcessor { private int inputHeight; private int inputWidth; private int numChannels; - + private RNNFormat rnnDataFormat = RNNFormat.NCW; @Getter(AccessLevel.NONE) @Setter(AccessLevel.NONE) private int product; public RnnToCnnPreProcessor(@JsonProperty("inputHeight") int inputHeight, - @JsonProperty("inputWidth") int inputWidth, @JsonProperty("numChannels") int numChannels) { + @JsonProperty("inputWidth") int inputWidth, + @JsonProperty("numChannels") int numChannels, + @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) { this.inputHeight = inputHeight; this.inputWidth = inputWidth; this.numChannels = numChannels; this.product = inputHeight * inputWidth * numChannels; + this.rnnDataFormat = rnnDataFormat; } + public RnnToCnnPreProcessor(int inputHeight, + int inputWidth, + int numChannels){ + this(inputHeight, inputWidth, numChannels, RNNFormat.NCW); + } @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { @@ -72,6 +82,9 @@ public class RnnToCnnPreProcessor implements InputPreProcessor { input = input.dup('f'); //Input: 3d activations (RNN) //Output: 4d activations (CNN) + if (rnnDataFormat == RNNFormat.NWC){ + input = input.permute(0, 2, 1); + } val shape = input.shape(); INDArray in2d; if (shape[0] == 1) { @@ -98,14 +111,17 @@ public class RnnToCnnPreProcessor implements InputPreProcessor { val shape = output.shape(); //First: reshape 4d to 2d INDArray twod = output.reshape('c', output.size(0), ArrayUtil.prod(output.shape()) / output.size(0)); - //Second: reshape 2d to 4d + //Second: reshape 2d to 3d INDArray reshaped = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, twod, 'f').reshape('f', miniBatchSize, shape[0] / miniBatchSize, product); - return reshaped.permute(0, 2, 1); + if (rnnDataFormat == RNNFormat.NCW) { + reshaped = reshaped.permute(0, 2, 1); + } + return reshaped; } @Override public RnnToCnnPreProcessor clone() { - return new RnnToCnnPreProcessor(inputHeight, inputWidth, numChannels); + return new RnnToCnnPreProcessor(inputHeight, inputWidth, numChannels, rnnDataFormat); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java index 7c92a7eaf..f3e9323a5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/RnnToFeedForwardPreProcessor.java @@ -16,11 +16,14 @@ package org.deeplearning4j.nn.conf.preprocessor; +import lombok.AllArgsConstructor; import lombok.Data; +import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.util.TimeSeriesUtils; import org.nd4j.linalg.api.ndarray.INDArray; @@ -28,6 +31,7 @@ import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.primitives.Pair; import 
org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.ArrayType; +import org.nd4j.shade.jackson.annotation.JsonProperty; import java.util.Arrays; @@ -47,8 +51,15 @@ import java.util.Arrays; */ @Data @Slf4j +@NoArgsConstructor public class RnnToFeedForwardPreProcessor implements InputPreProcessor { + private RNNFormat rnnDataFormat = RNNFormat.NCW; + + public RnnToFeedForwardPreProcessor(@JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat){ + if(rnnDataFormat != null) + this.rnnDataFormat = rnnDataFormat; + } @Override public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { //Need to reshape RNN activations (3d) activations to 2d (for input into feed forward layer) @@ -59,10 +70,13 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor { if (input.ordering() != 'f' || !Shape.hasDefaultStridesForShape(input)) input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'f'); + if (rnnDataFormat == RNNFormat.NWC){ + input = input.permute(0, 2, 1); + } val shape = input.shape(); INDArray ret; if (shape[0] == 1) { - ret = input.tensorAlongDimension(0, 1, 2).permutei(1, 0); //Edge case: miniBatchSize==1 + ret = input.tensorAlongDimension(0, 1, 2).permute(1, 0); //Edge case: miniBatchSize==1 } else if (shape[2] == 1) { ret = input.tensorAlongDimension(0, 1, 0); //Edge case: timeSeriesLength=1 } else { @@ -85,17 +99,15 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor { val shape = output.shape(); INDArray reshaped = output.reshape('f', miniBatchSize, shape[0] / miniBatchSize, shape[1]); - return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, reshaped.permute(0, 2, 1)); + if (rnnDataFormat == RNNFormat.NCW){ + reshaped = reshaped.permute(0, 2, 1); + } + return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, reshaped); } @Override public RnnToFeedForwardPreProcessor clone() { - try { - RnnToFeedForwardPreProcessor clone = (RnnToFeedForwardPreProcessor) super.clone(); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } + return new RnnToFeedForwardPreProcessor(rnnDataFormat); } @Override @@ -105,7 +117,7 @@ public class RnnToFeedForwardPreProcessor implements InputPreProcessor { } InputType.InputTypeRecurrent rnn = (InputType.InputTypeRecurrent) inputType; - return InputType.feedForward(rnn.getSize()); + return InputType.feedForward(rnn.getSize(), rnn.getFormat()); } @Override diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java new file mode 100644 index 000000000..ec50a5edb --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java @@ -0,0 +1,52 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.nn.conf.serde.format; + +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.DataFormat; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.nd4j.shade.jackson.core.JsonParser; +import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.databind.DeserializationContext; +import org.nd4j.shade.jackson.databind.JsonDeserializer; +import org.nd4j.shade.jackson.databind.JsonNode; + +import java.io.IOException; + +/** + * Simple JSON deserializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat} + * + * @author Alex Black + */ +public class DataFormatDeserializer extends JsonDeserializer { + @Override + public DataFormat deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + JsonNode node = jp.getCodec().readTree(jp); + String text = node.textValue(); + switch (text){ + case "NCHW": + return CNN2DFormat.NCHW; + case "NHWC": + return CNN2DFormat.NHWC; + case "NCW": + return RNNFormat.NCW; + case "NWC": + return RNNFormat.NWC; + default: + return null; + } + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java new file mode 100644 index 000000000..9abe90d38 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatSerializer.java @@ -0,0 +1,37 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.nn.conf.serde.format; + +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.DataFormat; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.nd4j.shade.jackson.core.JsonGenerator; +import org.nd4j.shade.jackson.databind.JsonSerializer; +import org.nd4j.shade.jackson.databind.SerializerProvider; + +import java.io.IOException; + +/** + * Simple JSON deserializer for {@link DataFormat} instances - {@link CNN2DFormat} and {@link RNNFormat} + * + * @author Alex Black + */ +public class DataFormatSerializer extends JsonSerializer { + @Override + public void serialize(DataFormat dataFormat, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException { + jsonGenerator.writeString(dataFormat.toString()); + } +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java index cb58a9813..4f864d0cc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.ArrayType; @@ -48,14 +49,16 @@ public class MergeVertex extends BaseGraphVertex { private long[][] forwardPassShapes; private int fwdPassRank; + private int mergeAxis; - public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType) { - this(graph, name, vertexIndex, null, null, dataType); + public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType, int mergeAxis) { + this(graph, name, vertexIndex, null, null, dataType, mergeAxis); } public MergeVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, - VertexIndices[] outputVertices, DataType dataType) { + VertexIndices[] outputVertices, DataType dataType, int mergeAxis) { super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); + this.mergeAxis = mergeAxis; } @Override @@ -92,7 +95,6 @@ public class MergeVertex extends BaseGraphVertex { forwardPassShapes = new long[in.length][0]; val nExamples = in[0].size(0); - int nOut = 0; fwdPassRank = in[0].rank(); for (int i = 0; i < in.length; i++) { val currShape = in[i].shape(); @@ -109,12 +111,11 @@ public class MergeVertex extends BaseGraphVertex { + Arrays.toString(in[0].shape()) + ", activations[" + i + "] shape: " + Arrays.toString(in[i].shape())); } - - nOut += currShape[1]; //Same dimension for all of CNNs, FF, RNNs } try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){ - return Nd4j.concat(1, in); + INDArray out = Nd4j.concat(mergeAxis, in); + return out; } } @@ -145,20 +146,16 @@ public class MergeVertex extends BaseGraphVertex { break; case 3: for (int i = 0; i < forwardPassShapes.length; i++) { - out[i].assign(epsilon.get(NDArrayIndex.all(), //All rows - 
NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //subset of columns - NDArrayIndex.all())); //All time steps + out[i].assign(epsilon.get(indices(3, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //All time steps - cumulative += forwardPassShapes[i][1]; + cumulative += forwardPassShapes[i][mergeAxis]; } break; case 4: for (int i = 0; i < forwardPassShapes.length; i++) { - out[i].assign(epsilon.get(NDArrayIndex.all(), - NDArrayIndex.interval(cumulative, cumulative + forwardPassShapes[i][1]), //Subset of depth - NDArrayIndex.all(), //Width - NDArrayIndex.all())); //height - cumulative += forwardPassShapes[i][1]; + out[i].assign(epsilon.get(indices(4, mergeAxis, cumulative, cumulative + forwardPassShapes[i][mergeAxis]))); //height + + cumulative += forwardPassShapes[i][mergeAxis]; } break; default: @@ -168,6 +165,19 @@ public class MergeVertex extends BaseGraphVertex { return new Pair<>(null, out); } + private INDArrayIndex[] indices(int num, int axis, long from, long to){ + INDArrayIndex[] out = new INDArrayIndex[num]; + for( int i=0; i extends AbstractLayer { @@ -371,7 +373,7 @@ public abstract class BaseLayer(retGradient, epsOut); } @@ -140,7 +150,10 @@ public class Convolution1DLayer extends ConvolutionLayer { // remove singleton fourth dimension from input and current epsilon epsNext = epsNext.reshape(epsNext.size(0), epsNext.size(1), epsNext.size(2)); input = origInput; - + if (getRnnDataFormat() == RNNFormat.NWC){ + epsNext = epsNext.permute(0, 2, 1); + this.input = input.permute(0, 2, 1); + } return new Pair<>(gradientEpsNext.getFirst(), epsNext); } @@ -185,7 +198,8 @@ public class Convolution1DLayer extends ConvolutionLayer { .s(c.getStride()[0]) .d(c.getDilation()[0]) .p(c.getPadding()[0]) - .dataFormat(Conv1DConfig.NCW) + .dataFormat((((org.deeplearning4j.nn.conf.layers.Convolution1DLayer) + layerConf()).getRnnDataFormat()== RNNFormat.NCW)?Conv1DConfig.NCW: Conv1DConfig.NCW) .paddingMode(PaddingMode.CAUSAL) .build(); INDArray w = getParam(ConvolutionParamInitializer.WEIGHT_KEY); @@ -209,6 +223,9 @@ public class Convolution1DLayer extends ConvolutionLayer { @Override public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr){ + if (getRnnDataFormat() == RNNFormat.NWC){ + this.input = input.permute(0, 2, 1); + } INDArray act4d = super.activate(training, workspaceMgr); INDArray act3d = act4d.reshape(act4d.size(0), act4d.size(1), act4d.size(2)); @@ -219,6 +236,10 @@ public class Convolution1DLayer extends ConvolutionLayer { act3d.shape(), maskOut.shape()); Broadcast.mul(act3d, maskOut, act3d, 0, 2); } + if (getRnnDataFormat() == RNNFormat.NWC){ + this.input = input.permute(0, 2, 1); + act3d = act3d.permute(0, 2, 1); + } return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, act3d); //Should be zero copy most of the time } @@ -231,4 +252,8 @@ public class Convolution1DLayer extends ConvolutionLayer { layerConf().getConvolutionMode()); return new Pair<>(reduced, currentMaskState); } + + private RNNFormat getRnnDataFormat(){ + return ((org.deeplearning4j.nn.conf.layers.Convolution1DLayer) layerConf()).getRnnDataFormat(); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionHelper.java index 8b2dd6940..c54812fbc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionHelper.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionHelper.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.layers.convolution; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.AlgoMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer.BwdDataAlgo; @@ -39,10 +40,10 @@ public interface ConvolutionHelper extends LayerHelper { Pair backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel, int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn, AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, - ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr); + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad, - AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr); + AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); INDArray activate(INDArray z, IActivation afn, boolean training); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java index 8ae1a8531..2b8a5148b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java @@ -20,6 +20,7 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -43,8 +44,6 @@ import org.nd4j.linalg.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.ArrayType; import org.nd4j.util.OneTimeLogger; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.util.Arrays; @@ -115,6 +114,14 @@ public class ConvolutionLayer extends BaseLayer p = preOutput4d(true, true, workspaceMgr); - delta = afn.backprop(p.getFirst(), epsilon).getFirst(); //TODO handle activation function params + INDArray z = p.getFirst(); + CNN2DFormat f = layerConf().getCnn2dDataFormat(); + if(f != CNN2DFormat.NCHW){ + z = z.permute(0,3,1,2); //NHWC to NCHW + } + delta = afn.backprop(z, epsilon).getFirst(); //TODO handle activation function params if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())) { + INDArray helperDelta = delta; + if(layerConf().getCnn2dDataFormat() == CNN2DFormat.NHWC) + helperDelta = delta.permute(0,2,3,1); //NCHW to NHWC if(!hasBias() && !(helper instanceof MKLDNNConvHelper)){ //MKL-DNN supports no bias, CuDNN doesn't @@ -168,10 +183,10 @@ public class ConvolutionLayer extends BaseLayer ret = null; try { - ret = helper.backpropGradient(input, weights, bias, delta, kernel, strides, + ret = 
helper.backpropGradient(origInput, weights, bias, helperDelta, kernel, strides, pad, biasGradView, weightGradView, afn, layerConf().getCudnnAlgoMode(), layerConf().getCudnnBwdFilterAlgo(), layerConf().getCudnnBwdDataAlgo(), - convolutionMode, dilation, workspaceMgr); + convolutionMode, dilation, layerConf().getCnn2dDataFormat(), workspaceMgr); } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Exception e){ @@ -254,6 +269,11 @@ public class ConvolutionLayer extends BaseLayer(retGradient, epsNext); } @@ -284,14 +304,16 @@ public class ConvolutionLayer extends BaseLayer Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) throw new ND4JArraySizeException(); @@ -337,7 +364,7 @@ public class ConvolutionLayer extends BaseLayer(z, forBackprop ? im2col2d : null); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java index da2cf1629..aaa34e20f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.val; import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -91,9 +92,19 @@ public class Cropping2DLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); + CNN2DFormat format = layerConf().getCnn2dDataFormat(); + boolean nchw = format == CNN2DFormat.NCHW; if (input.rank() != 4) { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to Convolution layer with shape " + Arrays.toString(input.shape()) - + ". Expected rank 4 array with shape [miniBatchSize, channels, inputHeight, inputWidth]. " + + ". Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + ". " + layerId()); } INDArray bias; @@ -77,8 +80,8 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); //No-op if correct type long miniBatch = input.size(0); - int inH = (int)input.size(2); - int inW = (int)input.size(3); + int inH = (int)input.size(nchw ? 2 : 1); + int inW = (int)input.size(nchw ? 
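            //Illustrative note (not part of the patch): NCHW activations are laid out as
            //[minibatch, channels, height, width] and NHWC as [minibatch, height, width, channels],
            //hence the nchw ? 2 : 1 and nchw ? 3 : 2 index picks here. E.g. a batch stored NCHW as
            //[32, 3, 224, 224] has input.size(2) == 224 (height); the same data stored NHWC is
            //[32, 224, 224, 3] and height is input.size(1).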
3 : 2); long inDepth = depthWiseWeights.size(2); int kH = (int) depthWiseWeights.size(0); @@ -90,25 +93,25 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { int[] pad; if (convolutionMode == ConvolutionMode.Same) { int[] outSize = ConvolutionUtils.getOutputSize( - input, kernel, strides, null, convolutionMode, dilation); + input, kernel, strides, null, convolutionMode, dilation, format); pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[]{inH, inW}, kernel, strides, dilation); } else { pad = layerConf().getPadding(); - ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); + ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); } INDArray biasGradView = gradientViews.get(DepthwiseConvolutionParamInitializer.BIAS_KEY); INDArray weightGradView = gradientViews.get(DepthwiseConvolutionParamInitializer.WEIGHT_KEY); - INDArray outEpsilon = workspaceMgr.create( - ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); + long[] epsShape = nchw ? new long[]{miniBatch, inDepth, inH, inW} : new long[]{miniBatch, inH, inW, inDepth}; + INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), epsShape, 'c'); - Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; + int sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[]{ kH, kW, strides[0], strides[1], pad[0], pad[1], dilation[0], dilation[1], - sameMode + sameMode, (nchw ? 0 : 1) }; INDArray delta; @@ -161,7 +164,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to DepthwiseConvolution2D (layer name = " + layerName + ", layer index = " + index + ") with shape " + Arrays.toString(input.shape()) + ". " - + "Expected rank 4 array with shape [miniBatchSize, layerInputDepth, inputHeight, inputWidth]." + + "Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + "." + (input.rank() == 2 ? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)" : "") + " " + layerId()); @@ -169,18 +172,22 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); //no-op if correct dtype + CNN2DFormat format = layerConf().getCnn2dDataFormat(); + boolean nchw = format == CNN2DFormat.NCHW; + long inDepth = depthWiseWeights.size(2); long depthMultiplier = depthWiseWeights.size(3); long outDepth = depthMultiplier * inDepth; - if (input.size(1) != inDepth) { + if (input.size(nchw ? 1 : 3) != inDepth) { String layerName = conf.getLayer().getLayerName(); if (layerName == null) layerName = "(not named)"; throw new DL4JInvalidInputException("Cannot do forward pass in DepthwiseConvolution2D layer " + "(layer name = " + layerName + ", layer index = " + index + "): input array channels does not match CNN layer configuration" - + " (data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]=" + + " (data input channels = " + input.size(1) + ", " + + (nchw ? 
"[minibatch,inputDepth,height,width]=" : "[minibatch,height,width,inputDepth]=") + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") " + layerId()); } @@ -194,30 +201,30 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { int[] pad; int[] outSize; if (convolutionMode == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } pad = ConvolutionUtils.getSameModeTopLeftPadding( - outSize, new int[]{(int) input.size(2), (int) input.size(3)}, kernel, strides, dilation); + outSize, new int[]{(int) input.size(nchw ? 2 : 1), (int) input.size(nchw ? 3 : 2)}, kernel, strides, dilation); } else { pad = layerConf().getPadding(); - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); } long outH = outSize[0]; long outW = outSize[1]; val miniBatch = input.size(0); - INDArray output = workspaceMgr.create( - ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), new long[]{miniBatch, outDepth, outH, outW}, 'c'); + long[] outShape = nchw ? new long[]{miniBatch, outDepth, outH, outW} : new long[]{miniBatch, outH, outW, outDepth}; + INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), outShape, 'c'); - Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; + int sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[]{ kH, kW, strides[0], strides[1], - pad[0], pad[1], dilation[0], dilation[1], sameMode + pad[0], pad[1], dilation[0], dilation[1], sameMode, (nchw ? 0 : 1) }; INDArray[] inputs; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java index 9808b3a24..48a9b8cfa 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.val; import org.deeplearning4j.exception.DL4JInvalidInputException; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -80,7 +81,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { if (input.rank() != 4) { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to SubsamplingLayer with shape " + Arrays.toString(input.shape()) - + ". Expected rank 4 array with shape [minibatchSize, channels, inputHeight, inputWidth]. " + + ". Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + ". 
" + layerId()); } INDArray bias; @@ -91,9 +92,12 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); + CNN2DFormat format = layerConf().getCnn2dDataFormat(); + boolean nchw = format == CNN2DFormat.NCHW; + long miniBatch = input.size(0); - int inH = (int)input.size(2); - int inW = (int)input.size(3); + int inH = (int)input.size(nchw ? 2 : 1); + int inW = (int)input.size(nchw ? 3 : 2); int inDepth = (int) depthWiseWeights.size(1); int kH = (int) depthWiseWeights.size(2); @@ -104,24 +108,26 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { int[] strides = layerConf().getStride(); int[] pad; if (convolutionMode == ConvolutionMode.Same) { - int[] outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation + int[] outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {inH, inW}, kernel, strides, dilation); } else { pad = layerConf().getPadding(); - ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation + ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); //Also performs validation } INDArray biasGradView = gradientViews.get(SeparableConvolutionParamInitializer.BIAS_KEY); INDArray depthWiseWeightGradView = gradientViews.get(SeparableConvolutionParamInitializer.DEPTH_WISE_WEIGHT_KEY); INDArray pointWiseWeightGradView = gradientViews.get(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY); - INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); + long[] epsShape = nchw ? new long[]{miniBatch, inDepth, inH, inW} : new long[]{miniBatch, inH, inW, inDepth}; + INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, depthWiseWeights.dataType(), epsShape, 'c'); - Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; + int sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[] { kH, kW, strides[0], strides[1], - pad[0], pad[1], dilation[0], dilation[1], sameMode + pad[0], pad[1], dilation[0], dilation[1], sameMode, + nchw ? 0 : 1 }; INDArray delta; @@ -180,6 +186,12 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); + CNN2DFormat format = layerConf().getCnn2dDataFormat(); + boolean nchw = format == CNN2DFormat.NCHW; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + if (input.rank() != 4) { String layerName = conf.getLayer().getLayerName(); if (layerName == null) @@ -187,7 +199,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to SeparableConvolution2D (layer name = " + layerName + ", layer index = " + index + ") with shape " + Arrays.toString(input.shape()) + ". " - + "Expected rank 4 array with shape [minibatchSize, layerInputDepth, inputHeight, inputWidth]." + + "Expected rank 4 array with shape " + format.dimensionNames() + "." + (input.rank() == 2 ? 
" (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)" : "") @@ -197,7 +209,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { long inDepth = depthWiseWeights.size(1); long outDepth = pointWiseWeights.size(0); - if (input.size(1) != inDepth) { + if (input.size(nchw ? 1 : 3) != inDepth) { String layerName = conf.getLayer().getLayerName(); if (layerName == null) layerName = "(not named)"; @@ -217,29 +229,31 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { int[] pad; int[] outSize; if (convolutionMode == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation if (input.size(2) > Integer.MAX_VALUE || input.size(3) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(2), (int) input.size(3)}, kernel, + pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(hIdx), (int) input.size(wIdx)}, kernel, strides, dilation ); } else { pad = layerConf().getPadding(); - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); //Also performs validation } int outH = outSize[0]; int outW = outSize[1]; val miniBatch = input.size(0); - INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), new long[]{miniBatch, outDepth, outH, outW}, 'c'); + long[] outShape = nchw ? new long[]{miniBatch, outDepth, outH, outW} : new long[]{miniBatch, outH, outW, outDepth}; + INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, depthWiseWeights.dataType(), outShape, 'c'); Integer sameMode = (convolutionMode == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[] { kH, kW, strides[0], strides[1], - pad[0], pad[1], dilation[0], dilation[1], sameMode + pad[0], pad[1], dilation[0], dilation[1], sameMode, + nchw ? 
0 : 1 }; //dl4j weights: depth [depthMultiplier, nIn, kH, kW], point [nOut, nIn * depthMultiplier, 1, 1] diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java index 586464716..720e756fd 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -91,17 +92,14 @@ public class SpaceToBatch extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); + INDArray input = this.input.castTo(epsilon.dataType()); + + boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW; long miniBatch = input.size(0); - long inDepth = input.size(1); - long inH = input.size(2); - long inW = input.size(3); + long inDepth = input.size(nchw ? 1 : 3); + long inH = input.size(nchw ? 2 : 1); + long inW = input.size(nchw ? 3 : 2); - INDArray input = this.input.castTo(dataType); //No-op if already correct type - - INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[]{1, miniBatch * inDepth * inH * inW}, 'c'); - INDArray reshapedEpsilon; - - if (isNHWC() == 1) { - reshapedEpsilon = outEpsilon.reshape('c', miniBatch, inH, inW, inDepth); - } else { - reshapedEpsilon = outEpsilon.reshape('c', miniBatch, inDepth, inH, inW); - } + long[] epsShape = nchw ? new long[]{miniBatch, inDepth, inH, inW} : new long[]{miniBatch, inH, inW, inDepth}; + INDArray outEpsilon = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, input.dataType(), epsShape, 'c'); Gradient gradient = new DefaultGradient(); int blockSize = getBlockSize(); + //Workaround for issue: https://github.com/eclipse/deeplearning4j/issues/8859 + if(!Shape.hasDefaultStridesForShape(epsilon)) + epsilon = epsilon.dup('c'); + CustomOp op = DynamicCustomOp.builder("depth_to_space") .addInputs(epsilon) - .addIntegerArguments(blockSize, isNHWC()) - .addOutputs(reshapedEpsilon) + .addIntegerArguments(blockSize, nchw ? 
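        //Illustrative note (not part of the patch): depth_to_space with block size b maps
        //NCHW [mb, C, H, W] -> [mb, C/(b*b), H*b, W*b], inverting the forward space_to_depth pass,
        //so with b == 2 an epsilon of shape [8, 16, 5, 5] comes back as [8, 4, 10, 10].
        //The integer argument appended below selects the layout: 0 = NCHW, 1 = NHWC.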
0 : 1) //nchw = 0, nhwc = 1 + .addOutputs(outEpsilon) .build(); Nd4j.getExecutioner().exec(op); - reshapedEpsilon = backpropDropOutIfPresent(reshapedEpsilon); - return new Pair<>(gradient, reshapedEpsilon); + return new Pair<>(gradient, outEpsilon); } protected INDArray preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { @@ -113,7 +111,7 @@ public class SpaceToDepth extends AbstractLayer { - private int[] padding; //[padTop, padBottom, padLeft, padRight] - public ZeroPaddingLayer(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); - this.padding = ((org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer) conf.getLayer()).getPadding(); } @Override @@ -65,9 +63,23 @@ public class ZeroPaddingLayer extends AbstractLayer((Gradient) new DefaultGradient(), epsNext); @@ -77,16 +89,28 @@ public class ZeroPaddingLayer extends AbstractLayer backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, - PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr); + PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, + CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, - ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr); + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java index b38945e95..85c3723e4 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.subsampling; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -108,15 +109,23 @@ public class SubsamplingLayer extends AbstractLayer ret = null; try{ ret = helper.backpropGradient(input, epsilon, kernel, strides, pad, - layerConf().getPoolingType(), convolutionMode, dilation, workspaceMgr); + layerConf().getPoolingType(), convolutionMode, dilation, dataFormat, workspaceMgr); } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Exception e){ @@ -188,26 +197,14 @@ public class SubsamplingLayer extends AbstractLayer(retGradient, epsAtInput); } - private static double minValue(){ - switch (Nd4j.dataType()){ - case DOUBLE: - return -Double.MAX_VALUE; - case FLOAT: - return -Float.MAX_VALUE; - case HALF: - return -65504.0; - default: - throw new IllegalStateException("Unexpected data type: " + Nd4j.dataType()); - } - } - @Override public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { @@ -219,16 +216,26 @@ public class SubsamplingLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java index 795e1f8af..efbe90aab 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,7 +19,7 @@ package org.deeplearning4j.nn.layers.convolution.upsampling; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; -import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -62,34 +63,41 @@ public class Upsampling2D extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - long miniBatch = (int) input.size(0); - long inDepth = (int) input.size(1); - long inH = (int) input.size(2); - long inW = (int) input.size(3); + CNN2DFormat format = getFormat(); + boolean nchw = format == CNN2DFormat.NCHW; - INDArray reshapedEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), new long[]{miniBatch, inDepth, inH, inW}, 'c'); + long miniBatch = (int) input.size(0); + long inDepth = (int) input.size(nchw ? 1 : 3); + long inH = (int) input.size(nchw ? 2 : 1); + long inW = (int) input.size(nchw ? 3 : 2); + + long[] epsShape = nchw ? new long[]{miniBatch, inDepth, inH, inW} : new long[]{miniBatch, inH, inW, inDepth}; + INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, epsilon.dataType(), epsShape, 'c'); Gradient gradient = new DefaultGradient(); - int[] intArgs = new int[] {1}; // 1 is for NCHW - - CustomOp op = DynamicCustomOp.builder("upsampling_bp") - .addIntegerArguments(intArgs) + .addIntegerArguments(nchw ? 
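        //Illustrative note (not part of the patch): upsampling_bp inverts the nearest-neighbour
        //duplication of the forward pass by summing each size x size block of epsilon into its
        //single source cell; e.g. with size == 2 and NCHW, an epsilon of shape [mb, C, 2*H, 2*W]
        //reduces to epsOut of shape [mb, C, H, W]. The integer argument below is the layout flag
        //(1 = NCHW, 0 = NHWC for this op).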
1 : 0) //1=NCHW, 0=NHWC .addInputs(input, epsilon) - .addOutputs(reshapedEpsilon) + .addOutputs(epsOut) .callInplace(false) .build(); Nd4j.getExecutioner().exec(op); - reshapedEpsilon = backpropDropOutIfPresent(reshapedEpsilon); - return new Pair<>(gradient, reshapedEpsilon); + epsOut = backpropDropOutIfPresent(epsOut); + + return new Pair<>(gradient, epsOut); } protected int[] getSize(){ return layerConf().getSize(); } + protected CNN2DFormat getFormat(){ + //Here so it can be overridden by Upsampling1D + return layerConf().getFormat(); + } + protected INDArray preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { assertInputSet(false); applyDropOutIfNecessary(training, workspaceMgr); @@ -97,7 +105,7 @@ public class Upsampling2D extends AbstractLayer [minibatch, nOut, seqLen] i.e., NWC -> NCW + } return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret); } @@ -177,8 +190,14 @@ public class EmbeddingSequenceLayer extends BaseLayer c = Class.forName("org.nd4j.nativeblas.Nd4jCpu$Environment"); Method m = c.getMethod("getInstance"); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java index 027f9d80d..6f825e3d8 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.layers.mkldnn; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper; @@ -28,9 +29,8 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm; -import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative; import org.nd4j.linalg.api.ops.impl.summarystats.Variance; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; @@ -47,7 +47,8 @@ import java.util.Map; */ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { private static final int[] RANK2_DIMS = {0}; - private static final int[] RANK4_DIMS = {0,2,3}; + private static final int[] RANK4_DIMS_NCHW = {0,2,3}; + private static final int[] RANK4_DIMS_NHWC = {0,1,2}; protected OpContext context; private INDArray meanCache; @@ -64,11 +65,18 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { @Override public Pair backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, - INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr) { + INDArray beta, INDArray dGammaView, INDArray dBetaView, double eps, + CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + + //Workaround for: https://github.com/eclipse/deeplearning4j/issues/8860 + if(!Shape.hasDefaultStridesForShape(epsilon)) + epsilon = epsilon.dup('c'); + if(input.dataType() != DataType.FLOAT) return null; //MKL-DNN only supports float - //TODO FIXME - AB 2019/11/01 - 
https://github.com/eclipse/deeplearning4j/issues/8335 + int axis = (input.rank() != 4 || format == CNN2DFormat.NCHW) ? 1 : 3; + List args = new ArrayList<>(); args.add(input); args.add(meanCache); @@ -85,7 +93,7 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { .addIntegerArguments( gamma == null ? 0 : 1, //Apply scale beta == null ? 0 : 1, //Apply beta - 1) //Axis (NCHW) + axis) //Axis (NCHW) - 1=NCHW, 3=NHWC .addFloatingPointArguments(eps) .build(); @@ -114,16 +122,18 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { @Override public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, INDArray var, - double decay, double eps, LayerWorkspaceMgr workspaceMgr) { + double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(x.dataType() != DataType.FLOAT) return null; //MKL-DNN only supports float + int axis = (x.rank() != 4 || format == CNN2DFormat.NCHW) ? 1 : 3; + if(context == null){ context = Nd4j.getExecutioner().buildContext(); context.setIArguments( ArrayUtil.fromBoolean(gamma != null), ArrayUtil.fromBoolean(beta != null), - 1); //Axis + axis); //Axis - 1 = NCHW, 3 = NHWC context.setTArguments(eps); } @@ -132,12 +142,22 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { if(training){ if(meanCache == null){ try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - meanCache = Nd4j.createUninitialized(x.dataType(), x.size(1)); - varCache = Nd4j.createUninitialized(x.dataType(), x.size(1)); + meanCache = Nd4j.createUninitialized(x.dataType(), x.size(axis)); + varCache = Nd4j.createUninitialized(x.dataType(), x.size(axis)); } } - x.mean(meanCache, x.rank() == 2 ? RANK2_DIMS : RANK4_DIMS); - Nd4j.exec(new Variance(x, varCache, false, x.rank() == 2 ? 
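            //Illustrative note (not part of the patch): batch-norm statistics reduce over every
            //dimension except the channel axis, yielding a length-C vector per statistic. Rank 2
            //[mb, C] reduces over {0}; rank 4 NCHW [mb, C, H, W] over {0, 2, 3}; rank 4 NHWC
            //[mb, H, W, C] over {0, 1, 2} - exactly the RANK4_DIMS_NCHW / RANK4_DIMS_NHWC
            //constants defined above.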
RANK2_DIMS : RANK4_DIMS)); + + int[] dims; + if(x.rank() == 2){ + dims = RANK2_DIMS; + } else if(format == CNN2DFormat.NCHW){ + dims = RANK4_DIMS_NCHW; + } else { + dims = RANK4_DIMS_NHWC; + } + + x.mean(meanCache, dims); + Nd4j.exec(new Variance(x, varCache, false, dims)); m = meanCache; v = varCache; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java index 9bbf4deae..5b360b349 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java @@ -16,6 +16,7 @@ package org.deeplearning4j.nn.layers.mkldnn; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -61,16 +62,19 @@ public class MKLDNNConvHelper implements ConvolutionHelper { public Pair backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel, int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn, ConvolutionLayer.AlgoMode mode, ConvolutionLayer.BwdFilterAlgo bwdFilterAlgo, ConvolutionLayer.BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, - int[] dilation, LayerWorkspaceMgr workspaceMgr) { + int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(input.dataType() != DataType.FLOAT || weights.dataType() != DataType.FLOAT) return null; //MKL-DNN only supports floating point dtype - //Note: conv2d op expects [kH, kW, iC, oC] weights... DL4J conv uses [oC, iC, kH, kW] - INDArray weightsPermute = weights.permute(2,3,1,0); - INDArray weightGradViewPermute = weightGradView.permute(2,3,1,0); + int hDim = 2; + int wDim = 3; + if(format == CNN2DFormat.NHWC){ + hDim = 1; + wDim = 2; + } if (convolutionMode == ConvolutionMode.Same) { - pad = ConvolutionUtils.getSameModeTopLeftPadding(new int[]{(int)delta.size(2), (int)delta.size(3)}, new int[] {(int) input.size(2), (int) input.size(3)}, + pad = ConvolutionUtils.getSameModeTopLeftPadding(new int[]{(int)delta.size(hDim), (int)delta.size(wDim)}, new int[] {(int) input.size(hDim), (int) input.size(wDim)}, kernel, strides, dilation); } @@ -81,14 +85,15 @@ public class MKLDNNConvHelper implements ConvolutionHelper { pad[0], pad[1], dilation[0], dilation[1], ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same), - 0 //0=NCHW + format == CNN2DFormat.NCHW ? 0 : 1, //0=NCHW, 1=NHWC + 1 //Weight format: 1 - [oC, iC, kH, kW] ); }; INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); - INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta}; - INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView}; + INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weights, delta} : new INDArray[]{input, weights, bias, delta}; + INDArray[] outputArr = biasGradView == null ? 
new INDArray[]{gradAtInput, weightGradView} : new INDArray[]{gradAtInput, weightGradView, biasGradView}; contextBwd.purge(); for( int i=0; i backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { + public Pair backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, int[] pad, + PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, + CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { if(poolingType == PoolingType.SUM || poolingType == PoolingType.PNORM) return null; INDArray gradAtInput = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape()); + int hIdx = 2; + int wIdx = 3; + if(format == CNN2DFormat.NHWC){ + hIdx = 1; + wIdx = 2; + } + if (convolutionMode == ConvolutionMode.Same) { - pad = ConvolutionUtils.getSameModeTopLeftPadding(new int[]{(int)epsilon.size(2), (int)epsilon.size(3)}, new int[] {(int)input.size(2), (int)input.size(3)}, kernel, strides, dilation); + pad = ConvolutionUtils.getSameModeTopLeftPadding(new int[]{(int)epsilon.size(hIdx), (int)epsilon.size(wIdx)}, new int[] {(int)input.size(hIdx), (int)input.size(wIdx)}, kernel, strides, dilation); } Pooling2DConfig conf = Pooling2DConfig.builder() @@ -75,7 +85,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { .sH(strides[0]).sW(strides[1]) .dH(dilation[0]).dW(dilation[1]) .pH(pad[0]).pW(pad[1]) - .isNHWC(false) + .isNHWC(format == CNN2DFormat.NHWC) .build(); switch (poolingType){ @@ -94,16 +104,26 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { } @Override - public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, LayerWorkspaceMgr workspaceMgr) { - int[] outSize; - if (convolutionMode == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation); //Also performs validation - pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int)input.size(2), (int)input.size(3)}, kernel, strides, dilation); - } else { - outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation); //Also performs validation + public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, PoolingType poolingType, + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + + int hIdx = 2; + int wIdx = 3; + if(format == CNN2DFormat.NHWC){ + hIdx = 1; + wIdx = 2; } - long[] outShape = new long[]{input.size(0), input.size(1), outSize[0], outSize[1]}; + int[] outSize; + if (convolutionMode == ConvolutionMode.Same) { + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation + pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int)input.size(hIdx), (int)input.size(wIdx)}, kernel, strides, dilation); + } else { + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); //Also performs validation + } + + long[] outShape = format == CNN2DFormat.NCHW ? 
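        //Illustrative note (not part of the patch): pooling preserves the minibatch and channel
        //counts and only shrinks the spatial dims, so the two layouts differ just in where outSize
        //lands. E.g. outSize = [112, 112] on a [32, 64, 224, 224] NCHW input gives
        //[32, 64, 112, 112], while the NHWC equivalent [32, 224, 224, 64] gives [32, 112, 112, 64].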
new long[]{input.size(0), input.size(1), outSize[0], outSize[1]} : + new long[]{input.size(0), outSize[0], outSize[1], input.size(3)}; INDArray output = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape); if(context == null){ @@ -115,7 +135,7 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper { dilation[0], dilation[1], ArrayUtil.fromBoolean(convolutionMode == ConvolutionMode.Same), 0, //Extra - not used? - 0); //0 = NCHW + format == CNN2DFormat.NCHW ? 0 : 1); //0 = NCHW, 1=NHWC } DynamicCustomOp op; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java index cd070185c..21362a0ae 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.normalization; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -112,6 +113,10 @@ public class BatchNormalization extends BaseLayer ret = null; try { ret = helper.backpropGradient(in, eps, shape, gamma, beta, dGammaView, dBetaView, - layerConf.getEps(), workspaceMgr); + layerConf.getEps(), format, workspaceMgr); } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Throwable t){ @@ -282,39 +288,43 @@ public class BatchNormalization extends BaseLayer backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, - INDArray dGammaView, INDArray dBetaView, double eps, LayerWorkspaceMgr workspaceMgr); + INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, + LayerWorkspaceMgr workspaceMgr); INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, - INDArray var, double decay, double eps, LayerWorkspaceMgr workspaceMgr); + INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr); INDArray getMeanCache(DataType dataType); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java index fe482ad62..3250176e9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.normalization; import lombok.val; import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -160,12 +161,17 @@ public class LocalResponseNormalization } } + boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW; + int chDim = nchw ? 1 : 3; + int hDim = nchw ? 2 : 1; + int wDim = nchw ? 
3 : 2; + Triple triple = activateHelper(true, workspaceMgr, true); INDArray activations = triple.getFirst(); INDArray unitScale = triple.getSecond(); INDArray scale = triple.getThird(); - val channel = input.size(1); + val channel = input.size(chDim); INDArray tmp, addVal; Gradient retGradient = new DefaultGradient(); INDArray reverse = activations.mul(epsilon); @@ -173,15 +179,25 @@ public class LocalResponseNormalization // sumPart = sum(a^j_{x,y} * gb^j_{x,y}) for (int i = 1; i < halfN + 1; i++) { - tmp = sumPart.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); - addVal = reverse.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); - sumPart.put(new INDArrayIndex[] {NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), - NDArrayIndex.all()}, tmp.addi(addVal)); + if(nchw) { + tmp = sumPart.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); + addVal = reverse.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), + NDArrayIndex.all()}, tmp.addi(addVal)); - tmp = sumPart.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); - addVal = reverse.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); - sumPart.put(new INDArrayIndex[] {NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), - NDArrayIndex.all()}, tmp.addi(addVal)); + tmp = sumPart.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); + addVal = reverse.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), + NDArrayIndex.all()}, tmp.addi(addVal)); + } else { + tmp = sumPart.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)); + addVal = reverse.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)}, tmp.addi(addVal)); + + tmp = sumPart.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)); + addVal = reverse.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)}, tmp.addi(addVal)); + } } // gx = gy * unitScale**-beta - 2 * alpha * beta * sumPart/unitScale * a^i_{x,y} - rearranged for more in-place ops @@ -228,7 +244,10 @@ public class LocalResponseNormalization } } - val channel = input.size(1); + boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW; + int chDim = nchw ? 
1 : 3; + + val channel = input.size(chDim); INDArray tmp, addVal; // x^2 = (a^j_{x,y})^2 INDArray activitySqr = input.mul(input); @@ -236,16 +255,27 @@ public class LocalResponseNormalization //sum_{j=max(0, i - n/2)}^{max(N-1, i + n/2)} (a^j_{x,y})^2 ) for (int i = 1; i < halfN + 1; i++) { - tmp = sumPart.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); - addVal = activitySqr.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), - NDArrayIndex.all()); - sumPart.put(new INDArrayIndex[] {NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), - NDArrayIndex.all()}, tmp.addi(addVal)); - tmp = sumPart.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); - addVal = activitySqr.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); - sumPart.put(new INDArrayIndex[] {NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), - NDArrayIndex.all()}, tmp.addi(addVal)); + if(nchw) { + tmp = sumPart.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); + addVal = activitySqr.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), + NDArrayIndex.all()); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), + NDArrayIndex.all()}, tmp.addi(addVal)); + + tmp = sumPart.get(NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), NDArrayIndex.all()); + addVal = activitySqr.get(NDArrayIndex.all(), interval(i, channel), NDArrayIndex.all(), NDArrayIndex.all()); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), interval(0, channel - i), NDArrayIndex.all(), + NDArrayIndex.all()}, tmp.addi(addVal)); + } else { + tmp = sumPart.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)); + addVal = activitySqr.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)}, tmp.addi(addVal)); + + tmp = sumPart.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)); + addVal = activitySqr.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(i, channel)); + sumPart.put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), interval(0, channel - i)}, tmp.addi(addVal)); + } } INDArray unitScale = null; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java index 265946bd8..5aa5bc88c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java @@ -18,7 +18,10 @@ package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.nn.api.layers.RecurrentLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.layers.BaseLayer; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -26,7 +29,7 @@ import java.util.HashMap; import java.util.Map; import 
java.util.concurrent.ConcurrentHashMap; -public abstract class BaseRecurrentLayer +public abstract class BaseRecurrentLayer extends BaseLayer implements RecurrentLayer { /** @@ -85,4 +88,19 @@ public abstract class BaseRecurrentLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { INDArray eFwd; INDArray eBwd; - + boolean permute = getRNNDataFormat() == RNNFormat.NWC && epsilon.rank() == 3; + if (permute){ + epsilon = epsilon.permute(0, 2, 1); + } val n = epsilon.size(1)/2; switch (layerConf.getMode()){ case ADD: @@ -165,6 +172,10 @@ public class BidirectionalLayer implements RecurrentLayer { eBwd = TimeSeriesUtils.reverseTimeSeries(eBwd, workspaceMgr, ArrayType.BP_WORKING_MEM); + if (permute){ + eFwd = eFwd.permute(0, 2, 1); + eBwd = eBwd.permute(0, 2, 1); + } Pair g1 = fwd.backpropGradient(eFwd, workspaceMgr); Pair g2 = bwd.backpropGradient(eBwd, workspaceMgr); @@ -176,7 +187,9 @@ public class BidirectionalLayer implements RecurrentLayer { g.gradientForVariable().put(BidirectionalParamInitializer.BACKWARD_PREFIX + e.getKey(), e.getValue()); } - INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2.getRight(), workspaceMgr, ArrayType.BP_WORKING_MEM); + INDArray g2Right = permute ? g2.getRight().permute(0, 2, 1): g2.getRight(); + INDArray g2Reversed = TimeSeriesUtils.reverseTimeSeries(g2Right, workspaceMgr, ArrayType.BP_WORKING_MEM); + g2Reversed = permute? g2Reversed.permute(0, 2, 1): g2Reversed; INDArray epsOut = g1.getRight().addi(g2Reversed); return new Pair<>(g, epsOut); @@ -186,25 +199,38 @@ public class BidirectionalLayer implements RecurrentLayer { public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { INDArray out1 = fwd.activate(training, workspaceMgr); INDArray out2 = bwd.activate(training, workspaceMgr); + boolean permute = getRNNDataFormat() == RNNFormat.NWC && out1.rank() == 3; + if(permute){ + out1 = out1.permute(0, 2, 1); + out2 = out2.permute(0, 2, 1); + } //Reverse the output time series. Note: when using LastTimeStepLayer, output can be rank 2 out2 = out2.rank() == 2 ? out2 : TimeSeriesUtils.reverseTimeSeries(out2, workspaceMgr, ArrayType.FF_WORKING_MEM); - + INDArray ret; switch (layerConf.getMode()){ case ADD: - return out1.addi(out2); + ret = out1.addi(out2); + break; case MUL: //TODO may be more efficient ways than this... 
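                //Illustrative note (not part of the patch): with out1 = [1, 2] and out2 = [3, 4],
                //ADD yields [4, 6], MUL yields [3, 8], AVERAGE yields [2, 3], and CONCAT along
                //dimension 1 yields [1, 2, 3, 4]. MUL keeps detached copies of both activations
                //below since backprop of a product needs the original factor values.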
this.outFwd = out1.detach(); this.outBwd = out2.detach(); - return workspaceMgr.dup(ArrayType.ACTIVATIONS, out1).muli(out2); + ret = workspaceMgr.dup(ArrayType.ACTIVATIONS, out1).muli(out2); + break; case AVERAGE: - return out1.addi(out2).muli(0.5); + ret = out1.addi(out2).muli(0.5); + break; case CONCAT: - INDArray ret = Nd4j.concat(1, out1, out2); - return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret); + ret = Nd4j.concat(1, out1, out2); + ret = workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret); + break; default: throw new RuntimeException("Unknown mode: " + layerConf.getMode()); } + if (permute){ + ret = ret.permute(0, 2, 1); + } + return ret; } @Override @@ -465,7 +491,9 @@ public class BidirectionalLayer implements RecurrentLayer { public void setInput(INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) { this.input = input; fwd.setInput(input, layerWorkspaceMgr); - + if (getRNNDataFormat() == RNNFormat.NWC){ + input = input.permute(0, 2, 1); + } INDArray reversed; if(!input.isAttached()){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { @@ -478,6 +506,9 @@ public class BidirectionalLayer implements RecurrentLayer { reversed = TimeSeriesUtils.reverseTimeSeries(input); } } + if (getRNNDataFormat() == RNNFormat.NWC){ + reversed = reversed.permute(0, 2, 1); + } bwd.setInput(reversed, layerWorkspaceMgr); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java index e0fd80842..dc155bff3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java @@ -88,12 +88,12 @@ public class GravesBidirectionalLSTM } final FwdPassReturn fwdPass = activateHelperDirectional(true, null, null, true, true, workspaceMgr); - + fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput); final Pair forwardsGradient = LSTMHelpers.backpropGradientHelper(this, this.conf, - this.layerConf().getGateActivationFn(), this.input, + this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), - getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), epsilon, + getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS, @@ -106,16 +106,17 @@ public class GravesBidirectionalLSTM final Pair backwardsGradient = LSTMHelpers.backpropGradientHelper(this, this.conf, - this.layerConf().getGateActivationFn(), this.input, + this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS), - getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), epsilon, + getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, backPass, false, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS, 
GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, gradientViews, maskArray, true, null, workspaceMgr, layerConf().isHelperAllowFallback()); - + forwardsGradient.setSecond(permuteIfNWC(forwardsGradient.getSecond())); + backwardsGradient.setSecond(permuteIfNWC(backwardsGradient.getSecond())); //merge the gradient, which is key value pair of String,INDArray //the keys for forwards and backwards should be different @@ -171,7 +172,7 @@ public class GravesBidirectionalLSTM } else { forwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), - this.input, getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), + permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), training, null, null, forBackprop || (cacheMode != CacheMode.NONE && training), true, @@ -179,7 +180,7 @@ public class GravesBidirectionalLSTM forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); backwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), - this.input, + permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS), training, null, null, @@ -187,6 +188,8 @@ public class GravesBidirectionalLSTM GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, maskArray, true, null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); + forwardsEval.fwdPassOutput = permuteIfNWC(forwardsEval.fwdPassOutput); + backwardsEval.fwdPassOutput = permuteIfNWC(backwardsEval.fwdPassOutput); cachedPassForward = forwardsEval; cachedPassBackward = backwardsEval; } @@ -228,10 +231,12 @@ public class GravesBidirectionalLSTM biasKey = GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS; } - return LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), this.input, + FwdPassReturn ret = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(recurrentKey), getParam(inputKey), getParam(biasKey), training, prevOutputActivations, prevMemCellState, forBackprop, forwards, inputKey, maskArray, true, null, forBackprop ? 
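        //Illustrative note (not part of the patch): LSTMHelpers operates on NCW data
        //([mb, size, seqLen]), so NWC activations ([mb, seqLen, size]) are permuted on the way in
        //and back on the way out. permuteIfNWC presumably applies permute(0, 2, 1) only when the
        //layer format is NWC; swapping dims 1 and 2 is its own inverse, so re-applying it to the
        //helper's outputs restores the caller's layout.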
cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); + ret.fwdPassOutput = permuteIfNWC(ret.fwdPassOutput); + return ret; } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java index b112672f9..551d4ff67 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.nd4j.base.Preconditions; @@ -89,17 +90,17 @@ public class GravesLSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this, - this.conf, this.layerConf().getGateActivationFn(), this.input, - recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true, + this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, GravesLSTMParamInitializer.BIAS_KEY, gradientViews, maskArray, true, null, workspaceMgr, layerConf().isHelperAllowFallback()); weightNoiseParams.clear(); - p.setSecond(backpropDropOutIfPresent(p.getSecond())); + p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond()))); return p; } @@ -117,8 +118,8 @@ public class GravesLSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this, - this.conf, this.layerConf().getGateActivationFn(), this.input, - recurrentWeights, inputWeights, epsilon, truncatedBPTT, tbpttBackwardLength, fwdPass, true, + this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, LSTMParamInitializer.INPUT_WEIGHT_KEY, LSTMParamInitializer.RECURRENT_WEIGHT_KEY, LSTMParamInitializer.BIAS_KEY, gradientViews, null, false, helper, workspaceMgr, layerConf().isHelperAllowFallback()); weightNoiseParams.clear(); - p.setSecond(backpropDropOutIfPresent(p.getSecond())); + p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond()))); return p; } @@ -153,6 +154,14 @@ public class LSTM extends BaseRecurrentLayer need to zero out these errors to avoid using errors from a masked time step // to calculate the parameter gradients. 
Mask array has shape [minibatch, timeSeriesLength] -> get column(this time step) timeStepMaskColumn = maskArray.getColumn(time, true); - deltaifogNext.muliColumnVector(timeStepMaskColumn); + deltaifogNext.muli(timeStepMaskColumn); //Later, the deltaifogNext is used to calculate: input weight gradients, recurrent weight gradients, bias gradients } @@ -737,7 +736,7 @@ public class LSTMHelpers { if (maskArray != null) { //Mask array is present: bidirectional RNN -> need to zero out these errors to avoid sending anything // but 0s to the layer below at this time step (for the given example) - epsilonNextSlice.muliColumnVector(timeStepMaskColumn); + epsilonNextSlice.muli(timeStepMaskColumn); } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java index 8e9f7c8f1..57b2659a6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.layers.recurrent; import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.util.TimeSeriesUtils; @@ -59,18 +60,38 @@ public class LastTimeStepLayer extends BaseWrapperLayer { @Override public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { - INDArray newEps = Nd4j.create(epsilon.dataType(), origOutputShape, 'f'); + long[] newEpsShape = origOutputShape; + + boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC; + INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f'); if(lastTimeStepIdxs == null){ //no mask case - newEps.put(new INDArrayIndex[]{all(), all(), point(origOutputShape[2]-1)}, epsilon); - } else { - INDArrayIndex[] arr = new INDArrayIndex[]{null, all(), null}; - //TODO probably possible to optimize this with reshape + scatter ops... - for( int i=0; i p = TimeSeriesUtils.pullLastTimeSteps(in, mask, workspaceMgr, arrayType); lastTimeStepIdxs = p.getSecond(); + return p.getFirst(); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java index 4e01ea084..a4f53fa7f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java @@ -30,6 +30,9 @@ import org.nd4j.linalg.primitives.Pair; import lombok.NonNull; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import static org.deeplearning4j.nn.conf.RNNFormat.NCW; +import static org.deeplearning4j.nn.conf.RNNFormat.NWC; + /** * Masks timesteps with activation equal to the specified masking value, defaulting to 0.0. * Assumes that the input shape is [batch_size, input_size, timesteps]. 
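 * A minimal standalone sketch of that masking rule (illustrative only, not part of the patch;
 * assumes NCW input and only the Nd4j calls used in the hunk below): a timestep is masked
 * out when every feature at that step equals the masking value.
 *
 *   INDArray input = Nd4j.createFromArray(new double[][][]{{{1, 0, 0}, {2, 0, 0}}}); //[1, 2, 3], NCW
 *   INDArray eqCount = input.eq(0.0).castTo(input.dataType()).sum(1);                //per (example, step)
 *   INDArray mask = eqCount.neq(input.shape()[1]).castTo(input.dataType());
 *   //mask == [[1.0, 0.0, 0.0]]: steps 1 and 2 have every feature equal to 0.0, so they are masked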
@@ -76,7 +79,11 @@ public class MaskZeroLayer extends BaseWrapperLayer { throw new IllegalArgumentException("Expected input of shape [batch_size, timestep_input_size, timestep], " + "got shape "+Arrays.toString(input.shape()) + " instead"); } - INDArray mask = input.eq(maskingValue).castTo(input.dataType()).sum(1).neq(input.shape()[1]); + if ((underlying instanceof BaseRecurrentLayer && + ((BaseRecurrentLayer)underlying).getDataFormat() == NWC)){ + input = input.permute(0, 2, 1); + } + INDArray mask = input.eq(maskingValue).castTo(input.dataType()).sum(1).neq(input.shape()[1]).castTo(input.dataType()); underlying.setMaskArray(mask.detach()); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java index dd1b03d63..4d06ae755 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java @@ -22,6 +22,7 @@ import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; @@ -60,6 +61,8 @@ public class RnnLossLayer extends BaseLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); + INDArray input = this.input; + INDArray labels = this.labels; if (input.rank() != 3) throw new UnsupportedOperationException( "Input is not rank 3. Expected rank 3 input of shape [minibatch, size, sequenceLength]. Got input with rank " + @@ -67,6 +70,10 @@ public class RnnLossLayer extends BaseLayer(gradient, delta3d); @@ -159,13 +168,21 @@ public class RnnLossLayer extends BaseLayer { public static final String STATE_KEY_PREV_ACTIVATION = "prevAct"; + public SimpleRnn(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); } @@ -92,6 +94,7 @@ public class SimpleRnn extends BaseRecurrentLayer p = activateHelper(null, true, true, workspaceMgr); @@ -125,8 +128,9 @@ public class SimpleRnn extends BaseRecurrentLayer= end; i--){ - INDArray dldaCurrent = epsilon.get(all(), all(), point(i)); + INDArray dldaCurrent = epsilon.get(all(), all(), point(i)).dup(); INDArray aCurrent = p.getFirst().get(all(), all(), point(i)); INDArray zCurrent = p.getSecond().get(all(), all(), point(i)); INDArray nCurrent = (hasLayerNorm() ? 
p.getThird().get(all(), all(), point(i)) : null); @@ -141,7 +145,7 @@ public class SimpleRnn extends BaseRecurrentLayer(grad, epsOut); } @@ -224,6 +229,7 @@ public class SimpleRnn extends BaseRecurrentLayer(out, outZ, outPreNorm, recPreNorm); } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java index 727d19eae..93eb8b9fc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java @@ -2,6 +2,7 @@ package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.workspace.ArrayType; @@ -22,11 +23,11 @@ import org.nd4j.linalg.util.ArrayUtil; */ public class TimeDistributedLayer extends BaseWrapperLayer { - private final int timeAxis; + private RNNFormat rnnDataFormat; - public TimeDistributedLayer(Layer underlying, int timeAxis) { + public TimeDistributedLayer(Layer underlying, RNNFormat rnnDataFormat) { super(underlying); - this.timeAxis = timeAxis; + this.rnnDataFormat = rnnDataFormat; } @@ -56,7 +57,7 @@ public class TimeDistributedLayer extends BaseWrapperLayer { protected INDArray reshape(INDArray array){ //Reshape the time axis to the minibatch axis //For example, for RNN -> FF (dense time distributed): [mb, size, seqLen] -> [mb x seqLen, size] - int axis = timeAxis; + int axis = (rnnDataFormat == RNNFormat.NCW) ? 2 : 1; if(axis < 0) axis += array.rank(); @@ -91,7 +92,7 @@ public class TimeDistributedLayer extends BaseWrapperLayer { protected INDArray revertReshape(INDArray toRevert, long minibatch){ - int axis = timeAxis; + int axis = (rnnDataFormat == RNNFormat.NCW)? 
2 : 1; if(axis < 0) axis += (toRevert.rank()+1); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java index a133fc843..78d36cdeb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java @@ -118,6 +118,7 @@ public class ConvolutionParamInitializer implements ParamInitializer { params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); conf.addVariable(WEIGHT_KEY); + conf.addVariable(BIAS_KEY); } else { INDArray weightView = paramsView; params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java index 399af4b2d..359b1913b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java @@ -22,6 +22,7 @@ import lombok.NonNull; import lombok.val; import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.exception.DL4JInvalidInputException; +import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -56,6 +57,10 @@ public class ConvolutionUtils { private ConvolutionUtils() { } + /** + * @deprecated Use {@link #getOutputSize(INDArray, int[], int[], int[], ConvolutionMode, int[], CNN2DFormat)} + */ + @Deprecated public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, ConvolutionMode convolutionMode) { return getOutputSize(inputData, kernel, strides, padding, convolutionMode, ONES); @@ -74,12 +79,15 @@ public class ConvolutionUtils { * @return Output size: int[2] with output height/width */ public static int[] getDeconvolutionOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, - ConvolutionMode convolutionMode, int[] dilation) { + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format) { + boolean nchw = format == CNN2DFormat.NCHW; + int hDim = nchw ? 2 : 1; + int wDim = nchw ?
3 : 2; - if (inputData.size(2) > Integer.MAX_VALUE || inputData.size(3) > Integer.MAX_VALUE) + if (inputData.size(hDim) > Integer.MAX_VALUE || inputData.size(wDim) > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int hIn = (int) inputData.size(2); - int wIn = (int) inputData.size(3); + int hIn = (int) inputData.size(hDim); + int wIn = (int) inputData.size(wDim); int[] eKernel = effectiveKernelSize(kernel, dilation); if (convolutionMode == ConvolutionMode.Same) { @@ -138,6 +146,15 @@ public class ConvolutionUtils { } + /** + * @deprecated Use {@link #getOutputSize(INDArray, int[], int[], int[], ConvolutionMode, int[], CNN2DFormat)} + */ + @Deprecated + public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, + ConvolutionMode convolutionMode, int[] dilation) { + return getOutputSize(inputData, kernel, strides, padding, convolutionMode, dilation, CNN2DFormat.NCHW); + } + /** * Get the output size (height/width) for the given input data and CNN configuration * @@ -147,14 +164,22 @@ public class ConvolutionUtils { * @param padding Padding (height/width) * @param convolutionMode Convolution mode (Same, Strict, Truncate) * @param dilation Kernel dilation (height/width) + * @param format Format for input activations * @return Output size: int[2] with output height/width */ public static int[] getOutputSize(INDArray inputData, int[] kernel, int[] strides, int[] padding, - ConvolutionMode convolutionMode, int[] dilation) { - if (inputData.size(2) > Integer.MAX_VALUE || inputData.size(3) > Integer.MAX_VALUE) + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format) { + int hDim = 2; + int wDim = 3; + if(format == CNN2DFormat.NHWC){ + hDim = 1; + wDim = 2; + } + + if (inputData.size(hDim) > Integer.MAX_VALUE || inputData.size(wDim) > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int inH = (int) inputData.size(2); - int inW = (int) inputData.size(3); + int inH = (int) inputData.size(hDim); + int inW = (int) inputData.size(wDim); //Determine the effective kernel size, accounting for dilation //http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html#dilated-convolutions @@ -491,18 +516,28 @@ public class ConvolutionUtils { } - public static INDArray reshape4dTo2d(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType type){ + public static INDArray reshape4dTo2d(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType type) { + return reshape4dTo2d(in, CNN2DFormat.NCHW, workspaceMgr, type); + } + + public static INDArray reshape4dTo2d(INDArray in, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type){ if (in.rank() != 4) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 4, got rank " + in.rank() + " with shape " + Arrays.toString(in.shape())); val shape = in.shape(); - //Reshape: from [n,c,h,w] to [n*h*w,c] - - INDArray out = in.permute(0, 2, 3, 1); - if (out.ordering() != 'c' || !Shape.strideDescendingCAscendingF(out)) - out = out.dup('c'); - return out.reshape('c', shape[0] * shape[2] * shape[3], shape[1]); + if(format == CNN2DFormat.NCHW){ + //Reshape: from [n,c,h,w] to [n*h*w,c] + INDArray out = in.permute(0, 2, 3, 1); + if (out.ordering() != 'c' || !Shape.strideDescendingCAscendingF(out)) + out = workspaceMgr.dup(type, out, 'c'); + return workspaceMgr.leverageTo(type, out.reshape('c', shape[0] * shape[2] * shape[3], shape[1])); + } else { + //Reshape: from [n,h,w,c] to [n*h*w,c] + if (in.ordering() != 'c' || !Shape.strideDescendingCAscendingF(in)) + in = 
workspaceMgr.dup(type, in, 'c'); + return workspaceMgr.leverageTo(type, in.reshape('c', shape[0] * shape[1] * shape[2], shape[3])); + } } public static INDArray reshape5dTo2d(@NonNull Convolution3D.DataFormat format, INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType type){ @@ -541,18 +576,23 @@ public class ConvolutionUtils { } } - public static INDArray reshape2dTo4d(INDArray in2d, long[] toShape, LayerWorkspaceMgr workspaceMgr, ArrayType type){ + public static INDArray reshape2dTo4d(INDArray in2d, long[] toShape, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type){ if(in2d.rank() != 2) throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2"); if (toShape.length != 4) throw new IllegalArgumentException("Invalid input: expect toShape with 4 elements: got " + Arrays.toString(toShape)); - //Reshape: from [n*h*w,c] to [n,h,w,c] to [n,c,h,w] - if(in2d.ordering() != 'c' || !Shape.hasDefaultStridesForShape(in2d)) + if (in2d.ordering() != 'c' || !Shape.hasDefaultStridesForShape(in2d)) in2d = workspaceMgr.dup(type, in2d, 'c'); - INDArray out = in2d.reshape('c', toShape[0], toShape[2], toShape[3], toShape[1]); - return workspaceMgr.leverageTo(type, out.permute(0, 3, 1, 2)); + if(format == CNN2DFormat.NCHW) { + //Reshape: from [n*h*w,c] to [n,h,w,c] to [n,c,h,w] + INDArray out = in2d.reshape('c', toShape[0], toShape[2], toShape[3], toShape[1]); + return workspaceMgr.leverageTo(type, out.permute(0, 3, 1, 2)); + } else { + //Reshape: from [n*h*w,c] to [n,h,w,c] + return workspaceMgr.leverageTo(type, in2d.reshape('c', toShape)); + } } public static INDArray reshape2dTo5d(Convolution3D.DataFormat format, INDArray in2d, long n, long d, long h, long w, long ch, LayerWorkspaceMgr workspaceMgr, ArrayType type){ @@ -563,7 +603,6 @@ public class ConvolutionUtils { if(in2d.ordering() != 'c' || !Shape.hasDefaultStridesForShape(in2d)) in2d = workspaceMgr.dup(type, in2d, 'c'); -// INDArray ndhwc = in2d.reshape('c', toShape[0], toShape[2], toShape[3], toShape[4], toShape[1]); INDArray ndhwc = in2d.reshape('c', n, d, h, w, ch); if(format == Convolution3D.DataFormat.NDHWC){ return workspaceMgr.leverageTo(type, ndhwc); @@ -572,11 +611,19 @@ public class ConvolutionUtils { } } - public static INDArray reshapeMaskIfRequired(INDArray mask, INDArray output, LayerWorkspaceMgr workspaceMgr, ArrayType type){ + /** + * @deprecated Use {@link #reshapeMaskIfRequired(INDArray, INDArray, CNN2DFormat, LayerWorkspaceMgr, ArrayType)} + */ + @Deprecated + public static INDArray reshapeMaskIfRequired(INDArray mask, INDArray output, LayerWorkspaceMgr workspaceMgr, ArrayType type) { + return reshapeMaskIfRequired(mask, output, null, workspaceMgr, type); + } + + public static INDArray reshapeMaskIfRequired(INDArray mask, INDArray output, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type){ if (mask == null) return null; if (mask.rank() == 2) { - return adapt2dMask(mask, output, workspaceMgr, type); + return adapt2dMask(mask, output, format, workspaceMgr, type); } else if (mask.rank() == 3) { return reshape3dMask(mask, workspaceMgr, type); } else { @@ -584,19 +631,30 @@ public class ConvolutionUtils { } } - public static INDArray adapt2dMask(INDArray mask, INDArray output, LayerWorkspaceMgr workspaceMgr, ArrayType type){ - //Input in [n,c,h,w] which is reshaped to [n*h*w,c], mask is [n,1] - //So: We'll broadcast to [n,1,h,w] then reshape to [n*h*w,1] required for the current DL4J loss functions... 
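// Worked example of the broadcast-then-reshape step in adapt2dMask below (illustrative): with
// n=2, h=2, w=3, a per-example mask of shape [2,1] is broadcast to [2,1,2,3] in NCHW (then
// permuted to [2,2,3,1]) or directly to [2,2,3,1] in NHWC, and reshaped 'c'-wise to
// [2*2*3, 1] = [12,1]: rows 0..5 carry mask[0] for the six spatial positions of example 0,
// rows 6..11 carry mask[1], matching the row order reshape4dTo2d produces for the activations.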
+ public static INDArray adapt2dMask(INDArray mask, INDArray output, @NonNull CNN2DFormat format, LayerWorkspaceMgr workspaceMgr, ArrayType type){ - //Use workaround for: https://github.com/deeplearning4j/nd4j/issues/2066 + if(format == CNN2DFormat.NCHW){ + //Input in [n,c,h,w] which is reshaped to [n*h*w,c], mask is [n,1] + //So: We'll broadcast to [n,1,h,w] then reshape to [n*h*w,1] required for the current DL4J loss functions... - val s = output.shape(); - INDArray bMask = workspaceMgr.create(type, mask.dataType(), new long[]{s[0], 1, s[2], s[3]}, 'c'); - Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, 0, 1)); + //Use workaround for: https://github.com/deeplearning4j/nd4j/issues/2066 - INDArray bMaskPermute = bMask.permute(0, 2, 3, 1).dup('c'); //Not sure if dup is strictly necessary... + val s = output.shape(); + INDArray bMask = workspaceMgr.create(type, mask.dataType(), new long[]{s[0], 1, s[2], s[3]}, 'c'); + Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, 0, 1)); - return workspaceMgr.leverageTo(type, bMaskPermute.reshape('c', s[0] * s[2] * s[3], 1)); + INDArray bMaskPermute = bMask.permute(0, 2, 3, 1).dup('c'); //Not sure if dup is strictly necessary... + + return workspaceMgr.leverageTo(type, bMaskPermute.reshape('c', s[0] * s[2] * s[3], 1)); + } else { + //Input in [n,h,w,c] which is reshaped to [n*h*w,c], mask is [n,1] + //So: We'll broadcast to [n,h,w,1] then reshape to [n*h*w,1] required for the current DL4J loss functions... + val s = output.shape(); + INDArray bMask = workspaceMgr.create(type, mask.dataType(), new long[]{s[0], s[1], s[2], 1}, 'c'); + Nd4j.getExecutioner().exec(new BroadcastCopyOp(bMask, mask, bMask, 0, 3)); + + return workspaceMgr.leverageTo(type, bMask.reshape('c', s[0] * s[1] * s[2], 1)); + } } public static INDArray reshape3dMask(INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType type){ @@ -679,10 +737,10 @@ public class ConvolutionUtils { int[] s = new int[]{stride, 1}; int[] d = new int[]{dilation, 1}; if (cm == ConvolutionMode.Same || cm == ConvolutionMode.Causal) { - outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, null, cm, d, CNN2DFormat.NCHW); //Also performs validation } else { pad = new int[]{padding, 0}; - outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, pad, cm, d); //Also performs validation + outSize = ConvolutionUtils.getOutputSize(reshaped4d, k, s, pad, cm, d, CNN2DFormat.NCHW); //Also performs validation } int outH = outSize[0]; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java index 80383698b..2a58387fe 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java @@ -17,6 +17,13 @@ package org.deeplearning4j.util; import lombok.val; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; +import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; +import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; +import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; +import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.nd4j.base.Preconditions; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; @@ -233,6 +240,12 @@ public class TimeSeriesUtils { return outReshape.reshape('f', in.size(0), in.size(1), in.size(2)); } + public static INDArray reverseTimeSeries(INDArray in, LayerWorkspaceMgr workspaceMgr, ArrayType arrayType, RNNFormat dataFormat){ + if (dataFormat == RNNFormat.NCW){ + return reverseTimeSeries(in, workspaceMgr, arrayType); + } + return reverseTimeSeries(in.permute(0, 2, 1), workspaceMgr, arrayType).permute(0, 2, 1); + } /** * Reverse an input time series along the time dimension * @@ -423,4 +436,25 @@ public class TimeSeriesUtils { return new Pair<>(workspaceMgr.leverageTo(arrayType, out), fwdPassTimeSteps); } + + /** + * Get the {@link RNNFormat} from the RNN layer, accounting for the presence of wrapper layers like Bidirectional, + * LastTimeStep, etc + * @param layer Layer to get the RNNFormat from + */ + public static RNNFormat getFormatFromRnnLayer(Layer layer){ + if(layer instanceof BaseRecurrentLayer){ + return ((BaseRecurrentLayer) layer).getRnnDataFormat(); + } else if(layer instanceof MaskZeroLayer){ + return getFormatFromRnnLayer(((MaskZeroLayer) layer).getUnderlying()); + } else if(layer instanceof Bidirectional){ + return getFormatFromRnnLayer(((Bidirectional) layer).getFwd()); + } else if(layer instanceof LastTimeStep){ + return getFormatFromRnnLayer(((LastTimeStep) layer).getUnderlying()); + } else if(layer instanceof TimeDistributed){ + return ((TimeDistributed) layer).getRnnDataFormat(); + } else { + throw new IllegalStateException("Unable to get RNNFormat from layer of type: " + layer); + } + } } diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java index 8e109689a..c0652d4e6 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java @@ -1,5 +1,6 @@ package org.deeplearning4j.remote; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.datavec.image.loader.Java2DNativeImageLoader; import org.deeplearning4j.BaseDL4JTest; @@ -32,6 +33,7 @@ import java.util.concurrent.TimeUnit; import static org.deeplearning4j.parallelism.inference.InferenceMode.SEQUENTIAL; import static org.junit.Assert.*; +@Slf4j public class BinaryModelServerTest extends BaseDL4JTest { private final int PORT = 18080; @@ -120,7 +122,7 @@ public class BinaryModelServerTest extends BaseDL4JTest { assertEquals(new Integer(1), result); } catch (Exception e){ - e.printStackTrace(); + log.error("",e); throw e; } finally { server.stop(); @@ -189,7 +191,7 @@ public class BinaryModelServerTest extends BaseDL4JTest { assertEquals(new Integer(1), results[2].get()); } catch (Exception e){ - e.printStackTrace(); + log.error("",e); throw e; } finally { server.stop(); @@ -244,7 +246,7 @@ public class BinaryModelServerTest extends BaseDL4JTest { assertEquals(28, result.getWidth()); } catch (Exception e){ - e.printStackTrace(); + log.error("",e); throw e; } finally { server.stop(); diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java 
b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java index e94ffbb40..5646ee558 100644 --- a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java @@ -585,7 +585,7 @@ public class JsonModelServerTest extends BaseDL4JTest { assertEquals(exp.argMax().getInt(0), out); } } catch (Exception e){ - e.printStackTrace(); + log.error("",e); throw e; } finally { server.stop(); @@ -640,7 +640,7 @@ public class JsonModelServerTest extends BaseDL4JTest { server.start(); //client.predict(new float[]{0.0f, 1.0f, 2.0f}); } catch (Exception e){ - e.printStackTrace(); + log.error("",e); throw e; } finally { server.stop(); @@ -700,7 +700,7 @@ public class JsonModelServerTest extends BaseDL4JTest { val result = client.predict(new float[]{0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}); assertNotNull(result); } catch (Exception e){ - e.printStackTrace(); + log.error("",e); throw e; } finally { server.stop(); diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java index 219573f0a..b149c102f 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java @@ -660,7 +660,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { //OK System.out.println("Expected exception: " + e.getMessage()); } catch (Exception e){ - e.printStackTrace(); + log.error("",e); fail("Expected other exception type"); } @@ -903,7 +903,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { int idx = t.getRight(); act[idx] = inf.output(t.getFirst(), t.getSecond()); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); failedCount.incrementAndGet(); } } @@ -955,7 +955,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { act[j] = inf.output(in.get(j), inMask); counter.incrementAndGet(); } catch (Exception e){ - e.printStackTrace(); + log.error("",e); failedCount.incrementAndGet(); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java index 919d98e97..41c9163e1 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java @@ -16,6 +16,7 @@ package org.deeplearning4j.spark.text.functions; +import lombok.extern.slf4j.Slf4j; import org.apache.spark.api.java.function.Function; import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess; import org.deeplearning4j.text.tokenization.tokenizerfactory.NGramTokenizerFactory; @@ -29,6 +30,7 @@ import 
java.util.List; * @author Adam Gibson */ @SuppressWarnings("unchecked") +@Slf4j public class TokenizerFunction implements Function> { private String tokenizerFactoryClazz; private String tokenizerPreprocessorClazz; @@ -69,7 +71,7 @@ public class TokenizerFunction implements Function> { tokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, nGrams, nGrams); } } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } return tokenizerFactory; } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java index 9e5ad1d67..de141c061 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java @@ -326,7 +326,7 @@ public class TextPipelineTest extends BaseSparkTest { sc.stop(); } - @Test + @Test @Ignore //AB 2020/04/20 https://github.com/eclipse/deeplearning4j/issues/8849 public void testCountCumSum() throws Exception { JavaSparkContext sc = getContext(); JavaRDD corpusRDD = getCorpusRDD(sc); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.java index 2cf251cda..d3a406ea8 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.java @@ -120,7 +120,7 @@ public class UpdatesConsumer implements UpdatesHandler { //log.info("Putting update to the queue, current size: [{}]", updatesBuffer.size()); updatesBuffer.put(array); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); throw new RuntimeException(e); } } else if (params != null && stepFunction != null) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java index 9058dd5af..9627c82bd 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java @@ -16,6 +16,7 @@ package org.deeplearning4j.spark.datavec; +import lombok.extern.slf4j.Slf4j; import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.Text; import org.apache.spark.api.java.function.PairFunction; @@ -35,6 +36,7 @@ import java.util.List; /** */ +@Slf4j public class DataVecByteDataSetFunction implements PairFunction, Double, DataSet> { private int labelIndex = 0; @@ -104,7 +106,7 @@ public class DataVecByteDataSetFunction implements PairFunction inputs = new ArrayList<>(); diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java index 3f111911c..801d2b420 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java @@ -16,6 +16,7 @@ package org.deeplearning4j.spark.datavec; +import lombok.extern.slf4j.Slf4j; import org.apache.spark.api.java.function.Function; import org.datavec.api.io.WritableConverter; import org.datavec.api.io.converters.WritableConverterException; @@ -36,6 +37,7 @@ import java.util.List; * Analogous to {@link RecordReaderDataSetIterator}, but in the context of Spark. * @author Alex Black */ +@Slf4j public class DataVecDataSetFunction implements Function, DataSet>, Serializable { private final int labelIndex; @@ -129,7 +131,8 @@ public class DataVecDataSetFunction implements Function, DataSet> try { current = converter.convert(current); } catch (WritableConverterException e) { - e.printStackTrace(); + + log.error("",e); } } if (regression) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java index c93952a4e..b3be0029c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java @@ -33,7 +33,7 @@ public class BalancedPartitionerTest { // the 10 first elements should go in the 1st partition for (int i = 0; i < 10; i++) { int p = bp.getPartition(i); - assertEquals("Found wrong partition output " + p + ", not 0", p, 0); + assertEquals("Found wrong partition output " + p + ", not 0", 0, p); } } @@ -43,7 +43,7 @@ public class BalancedPartitionerTest { // the 10 first elements should go in the 1st partition for (int i = 0; i < 10; i++) { int p = bp.getPartition(i); - assertEquals("Found wrong partition output " + p + ", not 0", p, 0); + assertEquals("Found wrong partition output " + p + ", not 0", 0, p); } } @@ -56,7 +56,7 @@ public class BalancedPartitionerTest { countPerPartition[p] += 1; } for (int i = 0; i < 10; i++) { - assertEquals(countPerPartition[i], 10); + assertEquals(10, countPerPartition[i]); } } @@ -70,9 +70,9 @@ public class BalancedPartitionerTest { } for (int i = 0; i < 10; i++) { if (i < 7) - assertEquals(countPerPartition[i], 10 + 1); + assertEquals(10 + 1, countPerPartition[i]); else - assertEquals(countPerPartition[i], 10); + assertEquals(10, countPerPartition[i]); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/BaseStatsListener.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/BaseStatsListener.java index 233633ab9..979f811e5 100644 --- 
a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/BaseStatsListener.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/BaseStatsListener.java @@ -385,7 +385,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { gpuMaxBytes[i] = nativeOps.getDeviceTotalMemory(0); gpuCurrentBytes[i] = gpuMaxBytes[i] - nativeOps.getDeviceFreeMemory(0); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java index 4c1347e91..07916643a 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java @@ -62,13 +62,11 @@ public class TestTransferStatsCollection extends BaseDL4JTest { new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build()) .setFeatureExtractor(0).build(); - File f = testDir.newFile("dl4jTestTransferStatsCollection.bin"); - f.delete(); + File dir = testDir.newFolder(); + File f = new File(dir, "dl4jTestTransferStatsCollection.bin"); net2.setListeners(new StatsListener(new FileStatsStorage(f))); //Previously: failed on frozen layers net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10))); - - f.deleteOnExit(); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java index 4aae630ca..1c12af527 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java @@ -17,6 +17,7 @@ package org.deeplearning4j.ui.weights; import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.datavec.image.loader.ImageLoader; import org.deeplearning4j.api.storage.Persistable; @@ -49,6 +50,7 @@ import java.util.List; /** * @author raver119@gmail.com */ +@Slf4j public class ConvolutionalIterationListener extends BaseTrainingListener { private enum Orientation { @@ -661,7 +663,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { try { ImageIO.write(renderImageGrayscale(array), "png", file); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } } @@ -670,7 +672,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { try { ImageIO.write(image, "png", file); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java index 7621faa3b..a489dc854 100644 ---
a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java @@ -16,6 +16,7 @@ package org.deeplearning4j.ui; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.datavec.image.loader.LFWLoader; import org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator; @@ -82,6 +83,7 @@ import static org.junit.Assert.fail; * @author raver119@gmail.com */ @Ignore +@Slf4j public class ManualTests { private static Logger log = LoggerFactory.getLogger(ManualTests.class); @@ -258,7 +260,7 @@ public class ManualTests { try { ImageIO.write(imageToRender, "png", file); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index cadc5050c..f90360d26 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -929,7 +929,7 @@ public class TrainModule implements UIModule { NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class); return new Triple<>(null, null, layer); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } } return null; diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java index deeb4a4bd..e4545ce55 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java @@ -34,6 +34,7 @@ import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.stats.StatsListener; import org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -53,7 +54,7 @@ import static org.junit.Assert.*; /** * @author Tamas Fenyvesi */ -@Slf4j +@Slf4j @Ignore //https://github.com/eclipse/deeplearning4j/issues/8891 public class TestVertxUIMultiSession extends BaseDL4JTest { @Before @@ -121,7 +122,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); assertTrue(uIServer.isAttached(ss)); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); fail(e.getMessage()); } finally { uIServer.detach(ss); @@ -206,5 +207,4 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { throws UnsupportedEncodingException { return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8")); } - } diff --git a/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/model/AlexNet.java b/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/model/AlexNet.java index 
e6461c121..41ddb7874 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/model/AlexNet.java +++ b/deeplearning4j/deeplearning4j-zoo/src/main/java/org/deeplearning4j/zoo/model/AlexNet.java @@ -21,7 +21,6 @@ import lombok.Builder; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; -import org.deeplearning4j.nn.conf.distribution.GaussianDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index de1fb07b3..cd1a36493 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -1045,7 +1045,7 @@ public class IntegrationTestRunner { act[j] = inf.output(in.get(j).getFirst(), inMask); counter.incrementAndGet(); } catch (Exception e){ - e.printStackTrace(); + log.error("",e); failedCount.incrementAndGet(); } } diff --git a/libnd4j/include/helpers/mman.h b/libnd4j/include/helpers/mman.h index 21f3fcceb..618ee23c3 100644 --- a/libnd4j/include/helpers/mman.h +++ b/libnd4j/include/helpers/mman.h @@ -138,15 +138,6 @@ void _mmap(Nd4jLong* result, size_t length, const char *fileName) { OffsetType off = 0; int prot = PROT_READ | PROT_WRITE; - // we need to convert long path (probably) to short pat (actually) - // it's Windows API, in the middle of 2018! - auto sz = GetShortPathName(fileName, nullptr, 0); - - auto shortName = new TCHAR[sz]; - GetShortPathName(fileName, shortName, sz); - - delete[] shortName; - #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable: 4293) @@ -170,7 +161,7 @@ void _mmap(Nd4jLong* result, size_t length, const char *fileName) { #pragma warning(pop) #endif - h = CreateFile(shortName, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_WRITE | FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); + h = CreateFileA(fileName, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_WRITE | FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); if (h == INVALID_HANDLE_VALUE) { errno = __map_mman_error(GetLastError(), EPERM); diff --git a/libnd4j/include/memory/impl/MemoryTracker.cpp b/libnd4j/include/memory/impl/MemoryTracker.cpp index be3019b08..5ebb4fd16 100644 --- a/libnd4j/include/memory/impl/MemoryTracker.cpp +++ b/libnd4j/include/memory/impl/MemoryTracker.cpp @@ -90,6 +90,9 @@ namespace sd { return result; } } + + // safe return + return std::string(""); } #endif diff --git a/libnd4j/include/ops/declarable/generic/linalg/tri.cpp b/libnd4j/include/ops/declarable/generic/linalg/tri.cpp index 19144e2fb..c7e1a125b 100644 --- a/libnd4j/include/ops/declarable/generic/linalg/tri.cpp +++ b/libnd4j/include/ops/declarable/generic/linalg/tri.cpp @@ -51,7 +51,9 @@ DECLARE_SHAPE_FN(tri) { const int rows = INT_ARG(0); const int cols = block.numI() > 1 ? INT_ARG(1) : rows; - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', {rows, cols})); + auto dtype = block.numD() ? 
D_ARG(0) : DataType::FLOAT32; + + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', {rows, cols})); } diff --git a/libnd4j/include/ops/declarable/generic/random/exponential.cpp b/libnd4j/include/ops/declarable/generic/random/exponential.cpp index 8605ffafe..cac3d1a88 100644 --- a/libnd4j/include/ops/declarable/generic/random/exponential.cpp +++ b/libnd4j/include/ops/declarable/generic/random/exponential.cpp @@ -27,29 +27,8 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(random_exponential, 1, 1, true, 1, 0) { - // uniform distribution + // random generator for distribution auto rng = block.randomGenerator(); - - // FIXME: to be implemented - /* - if (rng == nullptr) - return Status::THROW("RNG is null, aborting..."); - - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); - - if (block.width() == 1) - functions::random::RandomFunction::template execTransform>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data()); - else { - auto y = INPUT_VARIABLE(1); - REQUIRE_TRUE(y->isSameShape(z), 0, "ExponentialDistribution: Y shape should be equal to Z shape"); - - functions::random::RandomFunction::template execTransform>(block.getRNG(), y->getBuffer(), y->getShapeInfo(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data()); - } - - STORE_RESULT(*z); -*/ - auto z = OUTPUT_VARIABLE(0); auto lambda = T_ARG(0); diff --git a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp index e915df7f0..607980f0d 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space.cpp @@ -75,7 +75,10 @@ CUSTOM_OP_IMPL(batch_to_space, 2, 1, false, 0, 1) { REQUIRE_TRUE(oH >= 0, 0, "BatchToSpace: crop top/bottom values are too big and cause negative output height dimension !"); REQUIRE_TRUE(oW >= 0, 0, "BatchToSpace: crop left/right values are too big and cause negative output width dimension !"); - helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::batchToSpace(block.launchContext(), *input, *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); + else + helpers::batchToSpace(block.launchContext(), input->dup(), *output, cropBottom, cropTop, cropLeft, cropRight, blockSize); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp index 312fff7ec..f62921cc2 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/batch_to_space_nd.cpp @@ -74,7 +74,10 @@ CUSTOM_OP_IMPL(batch_to_space_nd, 3, 1, false, 0, 0) { REQUIRE_TRUE(outSpatialDim >= 0, 0, "BatchToSpaceND: crop left/right values are too big and cause negative output spatial dimension/dimensions !"); } - helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::batchToSpaceND(block.launchContext(), *input, *blockShape, *crop, *output); + else + helpers::batchToSpaceND(block.launchContext(), input->dup(), *blockShape, *crop, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp 
b/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp index 63c351c34..dcf827eb1 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/depth_to_space.cpp @@ -44,7 +44,10 @@ namespace ops { auto output = OUTPUT_VARIABLE(0); - helpers::_depthToSpace(block.launchContext(), input, output, block_size, isNHWC); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::_depthToSpace(block.launchContext(), *input, output, block_size, isNHWC); + else + helpers::_depthToSpace(block.launchContext(), input->dup(), output, block_size, isNHWC); STORE_RESULT(output); diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp index ceb953979..401b68d00 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp @@ -85,7 +85,7 @@ namespace ops { // check the consistency of input dimensions to reverse along shape::checkDimensions(input->rankOf(), axis); // we just reverse back original array - helpers::reverse(block.launchContext(), eps, output, &axis, true); + helpers::reverse(block.launchContext(), eps, output, &axis, false); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp index 12b981ac2..9a1683818 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch.cpp @@ -36,6 +36,7 @@ CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) { auto output = OUTPUT_VARIABLE(0); + const uint blockSize = INT_ARG(0); REQUIRE_TRUE(blockSize >= 2, 0, "SpaceToBatch: integer parameter block_size must be >= 2, but got %i instead", blockSize); @@ -52,7 +53,10 @@ CUSTOM_OP_IMPL(space_to_batch, 2, 1, false, 0, 1) { REQUIRE_TRUE((input->sizeAt(1) + padBottom + padTop) % blockSize == 0 && (input->sizeAt(2) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatch: after padding, second and third dimensions of input array must be divisible by blockSize !"); - helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::spaceToBatch(block.launchContext(), *input, *output, padBottom, padTop, padLeft, padRight, blockSize); + else + helpers::spaceToBatch(block.launchContext(), input->dup(), *output, padBottom, padTop, padLeft, padRight, blockSize); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp index a782f5b02..0b8c4152d 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_batch_nd.cpp @@ -56,7 +56,10 @@ CUSTOM_OP_IMPL(space_to_batch_nd, 3, 1, false, 0, 0) { REQUIRE_TRUE((input->sizeAt(i + 1) + padLeft + padRight) % blockSize == 0, 0, "SpaceToBatchND: after padding, spatial dimensions of input array must be divisible by blockSize !"); } - helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::spaceToBatchND(block.launchContext(), *input, *blockShape, *padding, *output); + else + 
helpers::spaceToBatchND(block.launchContext(), input->dup(), *blockShape, *padding, *output); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp b/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp index 3daf62ccd..b831dce2f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/space_to_depth.cpp @@ -51,7 +51,10 @@ namespace ops { auto output = OUTPUT_VARIABLE(0); - helpers::_spaceTodepth(block.launchContext(), input, output, block_size, isNHWC); + if (shape::strideDescendingCAscendingF(input->shapeInfo())) + helpers::_spaceTodepth(block.launchContext(), *input, output, block_size, isNHWC); + else + helpers::_spaceTodepth(block.launchContext(), input->dup(), output, block_size, isNHWC); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp b/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp index 700e5b8dd..598b3dc30 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/d_t_s.cpp @@ -26,14 +26,14 @@ namespace ops { namespace helpers { template - static void __depthToSpace(NDArray *input, NDArray *output, int block_size, bool isNHWC) { - T *input_ptr = reinterpret_cast(input->buffer()); + static void __depthToSpace(const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + T *input_ptr = reinterpret_cast(input.getBuffer()); T *output_ptr = reinterpret_cast(output->buffer()); - const int batch_size = input->sizeAt(0); - const int input_depth = isNHWC ? input->sizeAt(3) : input->sizeAt(1); - const int input_height = isNHWC ? input->sizeAt(1) : input->sizeAt(2); - const int input_width = isNHWC ? input->sizeAt(2) : input->sizeAt(3); + const int batch_size = input.sizeAt(0); + const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); + const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); + const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); const int output_height = isNHWC ? 
output->sizeAt(1) : output->sizeAt(2); @@ -93,13 +93,13 @@ namespace helpers { } } - void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - auto xType = input->dataType(); + void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + auto xType = input.dataType(); BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (input, output, block_size, isNHWC), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (NDArray *input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp b/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp index 32968b486..5668ea422 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/s_t_d.cpp @@ -25,14 +25,14 @@ namespace sd { namespace ops { namespace helpers { template - static void _spaceTodepth_(NDArray *input, NDArray *output, int block_size, bool isNHWC) { - auto input_ptr = reinterpret_cast(input->buffer()); + static void _spaceTodepth_(const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + auto input_ptr = reinterpret_cast(input.getBuffer()); auto output_ptr = reinterpret_cast(output->buffer()); - const int batch_size = input->sizeAt(0); - const int input_depth = isNHWC ? input->sizeAt(3) : input->sizeAt(1); - const int input_height = isNHWC ? input->sizeAt(1) : input->sizeAt(2); - const int input_width = isNHWC ? input->sizeAt(2) : input->sizeAt(3); + const int batch_size = input.sizeAt(0); + const int input_depth = isNHWC ? input.sizeAt(3) : input.sizeAt(1); + const int input_height = isNHWC ? input.sizeAt(1) : input.sizeAt(2); + const int input_width = isNHWC ? input.sizeAt(2) : input.sizeAt(3); const int output_depth = isNHWC ? output->sizeAt(3) : output->sizeAt(1); const int output_height = isNHWC ? 
output->sizeAt(1) : output->sizeAt(2); @@ -97,11 +97,11 @@ namespace helpers { } } - void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - BUILD_SINGLE_SELECTOR(input->dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES); + void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (input, output, block_size, isNHWC), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (NDArray *input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu b/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu index 35103d18b..fc3b04ee8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/d_t_s.cu @@ -88,20 +88,20 @@ namespace helpers { template - static void __depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - depthToSpaceKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); + static void __depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + depthToSpaceKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); } - void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - auto xType = input->dataType(); + void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + auto xType = input.dataType(); - NDArray::prepareSpecialUse({output}, {input}); + NDArray::prepareSpecialUse({output}, {&input}); BUILD_SINGLE_SELECTOR(xType, __depthToSpace, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {input}); + NDArray::registerSpecialUse({output}, {&input}); } - BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void __depthToSpace, (sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC);, LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu b/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu index a5ae42e78..4290a57c6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/s_t_d.cu @@ -90,17 +90,17 @@ namespace helpers { } template - static void _spaceTodepth_(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - spaceToDepthKernel<<<512, 512, 1024, *context->getCudaStream()>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); + static void _spaceTodepth_(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + spaceToDepthKernel<<<512, 512, 1024, 
*context->getCudaStream()>>>(input.getSpecialBuffer(), input.getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), block_size, isNHWC); } - void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC) { - NDArray::prepareSpecialUse({output}, {input}); - BUILD_SINGLE_SELECTOR(input->dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); - NDArray::registerSpecialUse({output}, {input}); + void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC) { + NDArray::prepareSpecialUse({output}, {&input}); + BUILD_SINGLE_SELECTOR(input.dataType(), _spaceTodepth_, (context, input, output, block_size, isNHWC), LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {&input}); } - BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void _spaceTodepth_, (sd::LaunchContext *context, const NDArray &input, NDArray *output, int block_size, bool isNHWC), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/d_t_s.h b/libnd4j/include/ops/declarable/helpers/d_t_s.h index 20c11ec24..e5ac58e5a 100644 --- a/libnd4j/include/ops/declarable/helpers/d_t_s.h +++ b/libnd4j/include/ops/declarable/helpers/d_t_s.h @@ -25,7 +25,7 @@ namespace sd { namespace ops { namespace helpers { - void _depthToSpace(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC); + void _depthToSpace(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/s_t_d.h b/libnd4j/include/ops/declarable/helpers/s_t_d.h index 6dbc64f21..7ef500f03 100644 --- a/libnd4j/include/ops/declarable/helpers/s_t_d.h +++ b/libnd4j/include/ops/declarable/helpers/s_t_d.h @@ -24,7 +24,7 @@ namespace sd { namespace ops { namespace helpers { - void _spaceTodepth(sd::LaunchContext * context, NDArray *input, NDArray *output, int block_size, bool isNHWC); + void _spaceTodepth(sd::LaunchContext * context, const NDArray &input, NDArray *output, int block_size, bool isNHWC); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 173880e63..6ae27b42a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -151,7 +151,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray ////////////////////////////////////////////////////////////////////////// -static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights, +static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights, NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) { // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x @@ -213,7 +213,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); - 
mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md); + mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md); // dLdI dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); @@ -242,7 +242,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // dLdO - mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); + mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); // mean auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer()); @@ -316,7 +316,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5 // dfdm / N - auto dfdm = dLdO->reduceAlongDimension(sd::reduce::Sum, excludedAxes); + auto dfdm = dLdO.reduceAlongDimension(sd::reduce::Sum, excludedAxes); dfdm *= stdInv; dfdm *= -Ninv; @@ -327,7 +327,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // (2/N)*dfdv NDArray dfdv(variance); // empty array with same shape as variance - (xMinusMean * *dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes); + (xMinusMean * dLdO).reduceAlongDimension(sd::reduce::Sum, dfdv, excludedAxes); dfdv *= stdInv*stdInv*stdInv; dfdv *= -Ninv; @@ -661,7 +661,10 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, dLdI, dLdW, epsilon, isNCHW); + if (shape::strideDescendingCAscendingF(dLdO->shapeInfo())) + batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW); + else + batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW); *dLdM = 0; *dLdV = 0; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index f3ef84e2f..0dd3b21f7 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -31,6 +31,20 @@ namespace sd { namespace ops { namespace platforms { + dnnl::memory::format_tag get_format_tag(const sd::NDArray &array) { + switch (array.rankOf()) { + case 1: + return dnnl::memory::format_tag::ab; + case 2: + return array.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba; + case 3: + return array.ordering() == 'c' ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba; + default: + throw std::runtime_error("MKLDNN matmul only supports 2D/3D arrays"); + } + } + + ////////////////////////////////////////////////////////////////////////// static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) { @@ -69,17 +83,15 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b NDArray* zR = xRank <= 3 ? z : new NDArray(z->reshape(z->ordering(), {z->lengthOf() / (z->sizeAt(-2) * z->sizeAt(-1)), z->sizeAt(-2), z->sizeAt(-1)})/*, false*/); // [M,K] x [K,N] = [M,N] - const int M = (xRank > 1) ? xTR->sizeAt(-2) : 1; - const int K = (xRank > 1) ? xTR->sizeAt(-1) : xTR->lengthOf(); - const int N = (yRank > 1) ? 
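The batchnorm_bp hook above now passes dLdO to MKL-DNN only when shape::strideDescendingCAscendingF reports C-contiguous strides, and otherwise substitutes a dense dup() copy. The same guard can be expressed against the public Nd4j Java API; a sketch only, using nothing beyond INDArray.ordering(), elementWiseStride() and dup():

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ContiguityGuard {
    public static void main(String[] args) {
        INDArray grad = Nd4j.create(DataType.FLOAT, new long[]{3, 4}, 'f'); // F-order gradient
        // Materialize a dense C-order copy before handing the buffer to a C-contiguous-only kernel
        INDArray dense = (grad.ordering() == 'c' && grad.elementWiseStride() == 1)
                ? grad
                : grad.dup('c');
        System.out.println(dense.ordering()); // c
    }
}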
yTR->sizeAt(-1) : 1; - const int bS = (xRank > 2) ? xTR->sizeAt(0) : 1; // [bS, M,K] x [bS, K,N] = [bS, M,N] + const int64_t M = (xRank > 1) ? xTR->sizeAt(-2) : 1; + const int64_t K = (xRank > 1) ? xTR->sizeAt(-1) : xTR->lengthOf(); + const int64_t N = (yRank > 1) ? yTR->sizeAt(-1) : 1; + const int64_t bS = (xRank > 2) ? xTR->sizeAt(0) : 1; // [bS, M,K] x [bS, K,N] = [bS, M,N] dnnl::memory::dims xShape = xRank < 3 ? dnnl::memory::dims({M, K}) : dnnl::memory::dims({bS, M, K}); dnnl::memory::dims yShape = xRank < 3 ? dnnl::memory::dims({K, N}) : dnnl::memory::dims({bS, K, N}); dnnl::memory::dims zShape = xRank < 3 ? dnnl::memory::dims({M, N}) : dnnl::memory::dims({bS, M, N}); - dnnl::memory::format_tag format = xRank < 3 ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::abc; - // x type dnnl::memory::data_type xType; if(x->dataType() == DataType::FLOAT32) @@ -114,9 +126,9 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b // memory descriptors for arrays // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); - if(xTR->ews() != 1 || xTR->ordering() != 'c') { + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR)); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR)); + if(xTR->ews() != 1) { x_user_md.data.format_kind = dnnl_blocked; // overrides format x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0); x_user_md.data.format_desc.blocking.strides[1] = xRank == 1 ? xTR->strideAt(0) : xTR->strideAt(1); @@ -125,9 +137,9 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b } // y - dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, dnnl::memory::format_tag::any); - dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, format); - if(yTR->ews() != 1 || yTR->ordering() != 'c') { + dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR)); + dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR)); + if(yTR->ews() != 1) { y_user_md.data.format_kind = dnnl_blocked; // overrides format y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0); y_user_md.data.format_desc.blocking.strides[1] = yRank == 1 ? yTR->strideAt(0) : yTR->strideAt(1); @@ -136,9 +148,9 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b } // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); - if(zR->ews() != 1 || zR->ordering() != 'c') { + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR)); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR)); + if(zR->ews() != 1) { z_user_md.data.format_kind = dnnl_blocked; // overrides format z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0); z_user_md.data.format_desc.blocking.strides[1] = zRank == 1 ? 
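get_format_tag above derives the MKL-DNN memory tag (ab/ba for rank 2, abc/cba for rank 3) from the array's ordering instead of hard-coding the row-major tag. From Java the distinction shows up directly in the strides that 'c' and 'f' arrays of the same shape report; a small sketch:

import java.util.Arrays;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class OrderingDemo {
    public static void main(String[] args) {
        INDArray c = Nd4j.create(DataType.FLOAT, new long[]{2, 3}, 'c');
        INDArray f = Nd4j.create(DataType.FLOAT, new long[]{2, 3}, 'f');
        System.out.println(Arrays.toString(c.stride())); // [3, 1] -> row-major, tag "ab"
        System.out.println(Arrays.toString(f.stride())); // [1, 2] -> column-major, tag "ba"
    }
}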
zR->strideAt(0) : zR->strideAt(1); @@ -289,14 +301,20 @@ PLATFORM_CHECK(matmul, ENGINE_CPU) { auto z = OUTPUT_VARIABLE(0); - const DataType xType = x->dataType(); - const DataType yType = y->dataType(); - const DataType zType = z->dataType(); + const auto xType = x->dataType(); + const auto yType = y->dataType(); + const auto zType = z->dataType(); - float alpha = block.numT() > 0 ? T_ARG(0) : 1.0; - float beta = block.numT() > 1 ? T_ARG(1) : 0.0; + float alpha = block.numT() > 0 ? T_ARG(0) : 1.0f; + float beta = block.numT() > 1 ? T_ARG(1) : 0.0f; - return !(z->ordering() == 'f' && beta != 0.f) && block.isUseMKLDNN() && x->rankOf() < 3 && + // we skip 2D cases if the result order is F or any array is not contiguous + bool skip2D = z->rankOf() == 2 && (z->ordering() == 'f' || x->ews() != 1 || y->ews() != 1 || z->ews() != 1); + + // we skip 3D cases if they are not C-contiguous + bool skip3D = z->rankOf() == 3 && (x->ordering() == 'f' || y->ordering() == 'f' || z->ordering() == 'f' || x->ews() != 1 || y->ews() != 1 || z->ews() != 1); + + return !skip2D && !skip3D && block.isUseMKLDNN() && x->rankOf() < 3 && ( (xType==DataType::FLOAT32 && yType==DataType::FLOAT32 && zType==DataType::FLOAT32) || (xType==DataType::HALF && yType==DataType::HALF && zType==DataType::FLOAT32) || diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp index a82bc2706..fab32f280 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp @@ -109,7 +109,7 @@ namespace sd { const DataType zType = z->dataType(); const int xRank = x->rankOf(); - bool bSupportedRanks = !x->isEmpty() && xRank < 7 && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); + bool bSupportedRanks = !x->isEmpty() && xRank < 7 && xRank > 0 && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); /* Source Destination f32 f32 @@ -214,7 +214,7 @@ namespace sd { const int xRank = x->rankOf(); const int dLdzRank = dLdz->rankOf(); - bool bSupportedRanks = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); + bool bSupportedRanks = xRank < 7 && xRank > 0 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty()); bSupportedRanks &= (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32); if (bSupportedRanks) { diff --git a/libnd4j/include/ops/random_ops.h b/libnd4j/include/ops/random_ops.h index d16b4f68a..939ffa975 100644 --- a/libnd4j/include/ops/random_ops.h +++ b/libnd4j/include/ops/random_ops.h @@ -119,8 +119,8 @@ namespace randomOps { random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { T lambda = extraParams[0]; - T x = helper->relativeT<T>(idx); //, T(0.f) , max); - T xVal = -sd::math::nd4j_log<T, T>(T(1.f) - x); + T x = helper->relativeT<T>(idx, sd::DataTypeUtils::min<T>(), T(1.f) - sd::DataTypeUtils::template min<T>()); // x from (0, 1), endpoints excluded + T xVal = -sd::math::nd4j_log<T, T>(x); return xVal <= (T)0.f ?
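The change just above pins the uniform draw to the open interval (0, 1), so the logarithm never sees 0 and the sampler stops producing infinities. A plain-Java sketch of the same inverse-CDF construction, with java.util.Random standing in for sd::graph::RandomGenerator; the mean 1/lambda and variance 1/lambda^2 it converges to are exactly what the RNGTests additions later in this patch assert:

import java.util.Random;

public class ExpSampler {
    // Inverse-CDF sample of Exponential(lambda): x = -ln(u) / lambda with u in (0, 1).
    static double sampleExponential(Random rng, double lambda) {
        double u = Math.max(rng.nextDouble(), Double.MIN_VALUE); // keep u strictly above 0
        return -Math.log(u) / lambda;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        double sum = 0;
        int n = 1_000_000;
        for (int i = 0; i < n; i++) sum += sampleExponential(rng, 2.0);
        System.out.println(sum / n); // ~0.5 = 1/lambda
    }
}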
(T)0.f : xVal / lambda; //pow((T) M_E, -(lambda * x)); } @@ -270,7 +270,7 @@ namespace randomOps { random_def T op(Nd4jLong idx, Nd4jLong length, sd::graph::RandomGenerator *helper, T *extraParams) { T lambda = extraParams[0]; - T x = helper->relativeT(idx, sd::DataTypeUtils::template min(), (T)1.f); + T x = helper->relativeT(idx, sd::DataTypeUtils::template min(), (T)1.f - sd::DataTypeUtils::template min()); return -sd::math::nd4j_log((T)1.f - x) / lambda; } diff --git a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp index bb3934994..c91c1c5c7 100644 --- a/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MklDnnTests.cpp @@ -19,15 +19,17 @@ // @author raver119@gmail.com // +#ifdef HAVE_MKLDNN + #include "testlayers.h" #include #include - -#ifdef HAVE_MKLDNN - #include +#include +#include -#endif + +using namespace sd; class MklDnnTests : public testing::Test { public: @@ -44,7 +46,6 @@ static void printer(std::initializer_list h TEST_F(MklDnnTests, helpers_includer) { // we need this block, to make sure all helpers are still available within binary, and not optimized out by linker -#ifdef HAVE_MKLDNN sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv2d; sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv2d_bp; @@ -83,6 +84,26 @@ TEST_F(MklDnnTests, helpers_includer) { printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, &xw_plus_b_bp }); - -#endif } + +TEST_F(MklDnnTests, test_tanh_1) { + auto x = NDArrayFactory::create(1.0f); + auto z = NDArrayFactory::create(0.0f); + + sd::ops::tanh op; + auto status = op.execute({&x}, {&z}); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(MklDnnTests, test_tanh_2) { + auto x = NDArrayFactory::create('c', {1}, {1.0f}); + auto z = NDArrayFactory::create('c', {1}, {0.0f}); + + sd::ops::tanh op; + auto status = op.execute({&x}, {&z}); + + ASSERT_EQ(Status::OK(), status); +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp index f00536e58..c6155eb0c 100644 --- a/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PerformanceTests.cpp @@ -60,6 +60,52 @@ public: #ifdef RELEASE_BUILD +TEST_F(PerformanceTests, test_matmul_c_f_1) { + int iterations = 500; + std::vector valuesC, valuesF; + for (int e = 0; e < iterations; e++) { + auto xc = NDArrayFactory::create('c', {512, 2048}); + auto yc = NDArrayFactory::create('c', {2048, 512}); + auto zc = NDArrayFactory::create('c', {512, 512}); + + auto xf = NDArrayFactory::create('f', {512, 2048}); + auto yf = NDArrayFactory::create('f', {2048, 512}); + auto zf = NDArrayFactory::create('f', {512, 512}); + + auto warm = xc.like(); + warm.linspace(1.0); + + //zc.linspace(1.0); + //zf.linspace(1.0); + + sd::ops::matmul op; + + auto timeStartF = std::chrono::system_clock::now(); + + op.execute({&xf, &yf}, {&zf}); + + auto timeEndF = std::chrono::system_clock::now(); + auto outerTimeF = std::chrono::duration_cast(timeEndF - timeStartF).count(); + + + auto timeStartC = std::chrono::system_clock::now(); + + op.execute({&xc, &yc}, {&zc}); + + auto timeEndC = std::chrono::system_clock::now(); + auto outerTimeC = std::chrono::duration_cast(timeEndC - timeStartC).count(); + + valuesF.emplace_back(outerTimeF); + 
valuesC.emplace_back(outerTimeC); + } + + std::sort(valuesC.begin(), valuesC.end()); + std::sort(valuesF.begin(), valuesF.end()); + + + nd4j_printf("Median time C: [%lld]; Median time F: [%lld];", valuesC[valuesC.size() / 2], valuesF[valuesF.size() / 2]); +} + TEST_F(PerformanceTests, test_maxpooling2d_1) { std::vector valuesX; // auto x = NDArrayFactory::create('c', {32, 3, 224, 224}); diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 889e194a6..56ca6b95e 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -790,6 +790,75 @@ TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) { } +TEST_F(RNGTests, Test_ExponentialDistribution_2_SGA) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + RandomGenerator oc(2716049175077475646L, -6182841917129177862L); + + sd::ops::random_exponential op; + RandomLauncher::fillExponential(x.getContext(), oc, &exp0, 2.f); + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + // +// z->printBuffer("\nExponential2+"); + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); + variance.printBuffer("Variance for exponential with param 2. (1/2 exp) is"); +} + +TEST_F(RNGTests, Test_ExponentialDistribution_3_SGA) { + auto x = NDArrayFactory::create('c', {2}, {1000, 1000}); + auto exp0 = NDArrayFactory::create('c', {1000, 1000}); + RandomGenerator oc(2716049175077475646L, -6182841917129177862L); + auto expMean = NDArrayFactory::create(0.5f); + auto expVar = NDArrayFactory::create(0.25f); + sd::ops::random_exponential op; + RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 2.f); + + auto result = op.evaluate({&x}, {1.}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + //ASSERT_TRUE(exp0.isSameShape(z)); + //ASSERT_FALSE(exp0.equalsTo(z)); + // +// z->printBuffer("\nExponential2+"); + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean"); + variance.printBuffer("Variance"); + ASSERT_NEAR(mean.e(0), 1.f, 1.e-2f); + ASSERT_NEAR(variance.e(0), 1.f, 1.e-2f); +// mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); +// variance.printBuffer("Variance for exponential with param 1. (1 exp) is"); +// ASSERT_FALSE(nexp0->equalsTo(z)); +// ASSERT_FALSE(nexp1->equalsTo(z)); +// ASSERT_FALSE(nexp2->equalsTo(z)); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 2.0 (1/2 exp) is"); + variance.printBuffer("Variance for exponential with param 2. 
(1/4 exp) is"); + ASSERT_TRUE(mean.equalsTo(expMean, 1.e-3)); + ASSERT_TRUE(variance.equalsTo(expVar, 1.e-3)); + RandomLauncher::fillExponential(exp0.getContext(), oc, &exp0, 1.f); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + mean.printBuffer("Mean for exponential with param 1.0 (1 exp) is"); + variance.printBuffer("Variance for exponential with param 1.0 (1 exp) is"); +} + TEST_F(RNGTests, Test_ExponentialDistribution_2) { auto x = NDArrayFactory::create('c', {2}, {10, 10}); auto y = NDArrayFactory::create('c', {10, 10}); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index e21b2d270..a4a2111a5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -189,7 +189,7 @@ public abstract class DifferentialFunction { try { return property.get(this); } catch (IllegalAccessException e) { - e.printStackTrace(); + log.error("",e); } return null; @@ -447,7 +447,7 @@ public abstract class DifferentialFunction { this.sameDiff = sameDiff; this.inPlace = inPlace; setInstanceId(); - if(sameDiff != null) { + if(sameDiff != null && args != null) { sameDiff.addArgsFor(args, this); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java index e85a93739..c67f8bd45 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java @@ -80,7 +80,7 @@ public class Loss { public static Loss sum(List losses) { - if (losses.size() == 0) + if (losses.isEmpty()) return new Loss(Collections.emptyList(), new double[0]); double[] lossValues = new double[losses.get(0).losses.length]; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java index eb0675da5..b337c0656 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java @@ -17,10 +17,7 @@ package org.nd4j.autodiff.listeners.impl; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Set; import lombok.Getter; import lombok.Setter; @@ -34,9 +31,6 @@ import org.nd4j.autodiff.listeners.records.LossCurve; import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.TrainingConfig; -import org.nd4j.autodiff.samediff.internal.SameDiffOp; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.MultiDataSet; /** * HistoryListener is mainly used internally to collect information such as the loss curve and evaluations, diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java index 776d26794..1faea9c66 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java @@ -154,11 +154,9 @@ public abstract class AbstractDependencyTracker { } } - if (allSatisfied) { - if (!this.allSatisfied.contains(t)) { - this.allSatisfied.add(t); - this.allSatisfiedQueue.add(t); - } + if (allSatisfied && !this.allSatisfied.contains(t)) { + this.allSatisfied.add(t); + this.allSatisfiedQueue.add(t); } } } @@ -278,25 +276,25 @@ public abstract class AbstractDependencyTracker { protected boolean isAllSatisfied(@NonNull T y) { Set set1 = dependencies.get(y); - boolean allSatisfied = true; + boolean retVal = true; if (set1 != null) { for (D d : set1) { - allSatisfied = isSatisfied(d); - if (!allSatisfied) + retVal = isSatisfied(d); + if (!retVal) break; } } - if (allSatisfied) { + if (retVal) { Set> set2 = orDependencies.get(y); if (set2 != null) { for (Pair p : set2) { - allSatisfied = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond()); - if (!allSatisfied) + retVal = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond()); + if (!retVal) break; } } } - return allSatisfied; + return retVal; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index d89fe05a5..a6273e0b3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -132,7 +132,7 @@ public abstract class AbstractSession { Preconditions.checkState(!variables.isEmpty() || !requiredActivations.isEmpty(), "Variables to perform forward pass for must not be empty"); if (requiredActivations == null) - requiredActivations = Collections.emptyList(); + requiredActivations = Collections.emptySet(); if (at == null) at = At.defaultAt(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java index e683acc47..d0d0b14cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/TrainingSession.java @@ -75,7 +75,7 @@ public class TrainingSession extends InferenceSession { this.listeners = filtered.isEmpty() ? null : filtered; } - List requiredActivations = new ArrayList<>(); + Set requiredActivations = new HashSet<>(); gradVarToVarMap = new HashMap<>(); //Key: gradient variable. 
Value: variable that the key is gradient for for (String s : paramsToTrain) { Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s); @@ -95,6 +95,12 @@ public class TrainingSession extends InferenceSession { gradVarToVarMap.put(grad.name(), s); } + //Also add evaluations - in case we want to evaluate something that isn't required to determine loss + // (hence wouldn't normally be calculated) + if(config.getTrainEvaluations() != null){ + requiredActivations.addAll(config.getTrainEvaluations().keySet()); + } + //Set up losses lossVarsToLossIdx = new LinkedHashMap<>(); List lossVars; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java index bb9f027c0..9d45aa76e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDCNN.java @@ -1052,4 +1052,38 @@ public class SDCNN extends SDOps { SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(sd,input, scaleH, scaleW, nchw).outputVariable(); return sd.updateVariableNameAndReference(out, name); } + + /** + * 3D Convolution layer operation - Upsampling 3d
+ * + * @param input Input, in NCDHW format if ncdhw is true, NDHWC otherwise (NUMERIC type) + * @param ncdhw If true: input is in NCDHW (minibatch, channels, depth, height, width) format. False: NDHWC format + * @param scaleD Scale to upsample in depth dimension + * @param scaleH Scale to upsample in height dimension + * @param scaleW Scale to upsample in width dimension + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling3d(SDVariable input, boolean ncdhw, int scaleD, int scaleH, + int scaleW) { + SDValidation.validateNumerical("upsampling3d", "input", input); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d(sd,input, ncdhw, scaleD, scaleH, scaleW).outputVariable(); + } + + /** + * 3D Convolution layer operation - Upsampling 3d
+ * + * @param name name May be null. Name for the output variable + * @param input Input, in NCDHW format if ncdhw is true, NDHWC otherwise (NUMERIC type) + * @param ncdhw If true: input is in NCDHW (minibatch, channels, depth, height, width) format. False: NDHWC format + * @param scaleD Scale to upsample in depth dimension + * @param scaleH Scale to upsample in height dimension + * @param scaleW Scale to upsample in width dimension + * @return output Upsampled input (NUMERIC type) + */ + public SDVariable upsampling3d(String name, SDVariable input, boolean ncdhw, int scaleD, + int scaleH, int scaleW) { + SDValidation.validateNumerical("upsampling3d", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d(sd,input, ncdhw, scaleD, scaleH, scaleW).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index ef030e952..a58d4d180 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -258,7 +258,7 @@ public class SDImage extends SDOps { /** * Resize images to size using the specified method.
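A short usage sketch for the new upsampling3d wrapper (shapes per the javadoc above; the op is reached via SameDiff.cnn() like the rest of SDCNN):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class Upsampling3dExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3, 4, 4, 4); // NCDHW
        SDVariable up = sd.cnn().upsampling3d(in, true, 2, 2, 2);            // -> [1, 3, 8, 8, 8]
    }
}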
* - * @param input 4D image [NCHW] (NUMERIC type) + * @param input 4D image [NHWC] (NUMERIC type) * @param size new height and width (INT type) * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. * @param antialis Whether to use an anti-aliasing filter when downsampling an image @@ -282,7 +282,7 @@ public class SDImage extends SDOps { * Resize images to size using the specified method.
* * @param name name May be null. Name for the output variable - * @param input 4D image [NCHW] (NUMERIC type) + * @param input 4D image [NHWC] (NUMERIC type) * @param size new height and width (INT type) * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. * @param antialis Whether to use an anti-aliasing filter when downsampling an image @@ -306,7 +306,7 @@ public class SDImage extends SDOps { /** * Resize images to size using the specified method.
* - * @param input 4D image [NCHW] (NUMERIC type) + * @param input 4D image [NHWC] (NUMERIC type) * @param size new height and width (INT type) * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. @@ -328,7 +328,7 @@ public class SDImage extends SDOps { * Resize images to size using the specified method.
* * @param name name May be null. Name for the output variable - * @param input 4D image [NCHW] (NUMERIC type) + * @param input 4D image [NHWC] (NUMERIC type) * @param size new height and width (INT type) * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java index 8dbb9d3b3..ae97b2c4a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLinalg.java @@ -23,6 +23,7 @@ import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; public class SDLinalg extends SDOps { public SDLinalg(SameDiff sameDiff) { @@ -558,4 +559,106 @@ public class SDLinalg extends SDOps { SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(sd,input, fullUV, computeUV, 16).outputVariable(); return sd.updateVariableNameAndReference(out, name); } + + /** + * An array with ones at and below the given diagonal and zeros elsewhere.
+ * + * @param dataType Data type + * @param row Number of rows in the output array + * @param column Number of columns in the output array + * @param diagonal Offset of the diagonal at and below which the array is filled with ones + * @return output (FLOATING_POINT type) + */ + public SDVariable tri(DataType dataType, int row, int column, int diagonal) { + return new org.nd4j.linalg.api.ops.custom.Tri(sd,dataType, row, column, diagonal).outputVariable(); + } + + /** + * An array with ones at and below the given diagonal and zeros elsewhere.
+ * + * @param name name May be null. Name for the output variable + * @param dataType Data type + * @param row Number of rows in the output array + * @param column Number of columns in the output array + * @param diagonal Offset of the diagonal at and below which the array is filled with ones + * @return output (FLOATING_POINT type) + */ + public SDVariable tri(String name, DataType dataType, int row, int column, int diagonal) { + SDVariable out = new org.nd4j.linalg.api.ops.custom.Tri(sd,dataType, row, column, diagonal).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * An array with ones at and below the given diagonal and zeros elsewhere.
+ * + * @param row Number of rows in the output array + * @param column Number of columns in the output array + * @return output (FLOATING_POINT type) + */ + public SDVariable tri(int row, int column) { + return new org.nd4j.linalg.api.ops.custom.Tri(sd,DataType.FLOAT, row, column, 0).outputVariable(); + } + + /** + * An array with ones at and below the given diagonal and zeros elsewhere.
+ * + * @param name name May be null. Name for the output variable + * @param row Number of rows in the output array + * @param column Number of columns in the output array + * @return output (FLOATING_POINT type) + */ + public SDVariable tri(String name, int row, int column) { + SDVariable out = new org.nd4j.linalg.api.ops.custom.Tri(sd,DataType.FLOAT, row, column, 0).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Upper triangle of an array. Return a copy of an input tensor with the elements below the k-th diagonal zeroed.
+ * + * @param input (NUMERIC type) + * @param diag Diagonal below which the elements are zeroed + * @return output (FLOATING_POINT type) + */ + public SDVariable triu(SDVariable input, int diag) { + SDValidation.validateNumerical("triu", "input", input); + return new org.nd4j.linalg.api.ops.custom.Triu(sd,input, diag).outputVariable(); + } + + /** + * Upper triangle of an array. Return a copy of an input tensor with the elements below the k-th diagonal zeroed.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @param diag Diagonal below which the elements are zeroed + * @return output (FLOATING_POINT type) + */ + public SDVariable triu(String name, SDVariable input, int diag) { + SDValidation.validateNumerical("triu", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.Triu(sd,input, diag).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Upper triangle of an array. Return a copy of an input tensor with the elements below the k-th diagonal zeroed.
+ * + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable triu(SDVariable input) { + SDValidation.validateNumerical("triu", "input", input); + return new org.nd4j.linalg.api.ops.custom.Triu(sd,input, 0).outputVariable(); + } + + /** + * Upper triangle of an array. Return a copy of an input tensor with the elements below the k-th diagonal zeroed.
+ * + * @param name name May be null. Name for the output variable + * @param input (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public SDVariable triu(String name, SDVariable input) { + SDValidation.validateNumerical("triu", "input", input); + SDVariable out = new org.nd4j.linalg.api.ops.custom.Triu(sd,input, 0).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 66d47f905..601a28a53 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -67,13 +67,13 @@ public class SDMath extends SDOps { * Looks up ids in a list of embedding tensors.
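Usage sketch for the new tri/triu wrappers (numpy-style semantics; assumes the SDLinalg namespace is reached via SameDiff.linalg(), as for the existing methods in this class):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class TriExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable lower = sd.linalg().tri(3, 5);      // 3x5 FLOAT, ones at and below the main diagonal
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, 4, 4);
        SDVariable upper = sd.linalg().triu(in, 1);    // zeroes everything below the first super-diagonal
    }
}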
* * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); - SDValidation.validateNumerical("EmbeddingLookup", "indices", indices); + SDValidation.validateInteger("EmbeddingLookup", "indices", indices); return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); } @@ -82,18 +82,72 @@ public class SDMath extends SDOps { * * @param name name May be null. Name for the output variable * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public SDVariable embeddingLookup(String name, SDVariable x, SDVariable indices, PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); - SDValidation.validateNumerical("EmbeddingLookup", "indices", indices); + SDValidation.validateInteger("EmbeddingLookup", "indices", indices); SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); return sd.updateVariableNameAndReference(out, name); } + /** + * Return an array of the indices of the max elements along tensor dimensions
+ * + * @param x Input tensor (NUMERIC type) + * @param dataType Data type + * @return output Indices of the max elements along the dimensions. (INT type) + */ + public SDVariable mergeMaxIndex(SDVariable[] x, DataType dataType) { + SDValidation.validateNumerical("MergeMaxIndex", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, dataType).outputVariable(); + } + + /** + * Return an array of the indices of the max elements along tensor dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param dataType Data type + * @return output Indices of the max elements along the dimensions. (INT type) + */ + public SDVariable mergeMaxIndex(String name, SDVariable[] x, DataType dataType) { + SDValidation.validateNumerical("MergeMaxIndex", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, dataType).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Return an array of the indices of the max elements along tensor dimensions
+ * + * @param x Input tensor (NUMERIC type) + * @return output Indices of the max elements along the dimensions. (INT type) + */ + public SDVariable mergeMaxIndex(SDVariable... x) { + SDValidation.validateNumerical("MergeMaxIndex", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, DataType.INT).outputVariable(); + } + + /** + * Return an array of the indices of the max elements along tensor dimensions
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @return output Indices of the max elements along the dimensions. (INT type) + */ + public SDVariable mergeMaxIndex(String name, SDVariable... x) { + SDValidation.validateNumerical("MergeMaxIndex", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, DataType.INT).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Elementwise absolute value operation: out = abs(x)
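Usage sketch for mergeMaxIndex: given several same-shaped inputs, the result holds, per element, the index of the input that contains the maximum value (values here are illustrative):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class MergeMaxIndexExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable a = sd.constant("a", Nd4j.createFromArray(1f, 9f));
        SDVariable b = sd.constant("b", Nd4j.createFromArray(5f, 2f));
        SDVariable idx = sd.math().mergeMaxIndex(a, b); // element-wise maxima come from b, then a -> [1, 0]
    }
}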
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java index de8148c02..ebb1a025d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRNN.java @@ -35,6 +35,48 @@ public class SDRNN extends SDOps { super(sameDiff); } + /** + * The GRU operation. Gated Recurrent Unit - Cho et al. 2014.
+ * + * @param x input [time, bS, nIn] (NUMERIC type) + * @param hLast initial cell output (at time step = 0) [bS, nOut] (NUMERIC type) + * @param Wx input-to-hidden weights, [nIn, 3*nOut] (NUMERIC type) + * @param Wh hidden-to-hidden weights, [nOut, 3*nOut] (NUMERIC type) + * @param biases biases, [3*nOut] (NUMERIC type) + * @return h cell outputs [time, bS, nOut], one output per time step (NUMERIC type) + */ + public SDVariable gru(SDVariable x, SDVariable hLast, SDVariable Wx, SDVariable Wh, + SDVariable biases) { + SDValidation.validateNumerical("gru", "x", x); + SDValidation.validateNumerical("gru", "hLast", hLast); + SDValidation.validateNumerical("gru", "Wx", Wx); + SDValidation.validateNumerical("gru", "Wh", Wh); + SDValidation.validateNumerical("gru", "biases", biases); + return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU(sd,x, hLast, Wx, Wh, biases).outputVariable(); + } + + /** + * The GRU operation. Gated Recurrent Unit - Cho et al. 2014.
+ * + * @param name name May be null. Name for the output variable + * @param x input [time, bS, nIn] (NUMERIC type) + * @param hLast initial cell output (at time step = 0) [bS, nOut] (NUMERIC type) + * @param Wx input-to-hidden weights, [nIn, 3*nOut] (NUMERIC type) + * @param Wh hidden-to-hidden weights, [nOut, 3*nOut] (NUMERIC type) + * @param biases biases, [3*nOut] (NUMERIC type) + * @return h cell outputs [time, bS, nOut], one output per time step (NUMERIC type) + */ + public SDVariable gru(String name, SDVariable x, SDVariable hLast, SDVariable Wx, SDVariable Wh, + SDVariable biases) { + SDValidation.validateNumerical("gru", "x", x); + SDValidation.validateNumerical("gru", "hLast", hLast); + SDValidation.validateNumerical("gru", "Wx", Wx); + SDValidation.validateNumerical("gru", "Wh", Wh); + SDValidation.validateNumerical("gru", "biases", biases); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU(sd,x, hLast, Wx, Wh, biases).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * The GRU cell. Does a single time step operation
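A sketch wiring up the new whole-sequence gru (shapes follow the javadoc above; placeholder names are illustrative):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class GruExample {
    public static void main(String[] args) {
        int time = 5, bS = 2, nIn = 3, nOut = 4;
        SameDiff sd = SameDiff.create();
        SDVariable x     = sd.placeHolder("x", DataType.FLOAT, time, bS, nIn);
        SDVariable hLast = sd.placeHolder("hLast", DataType.FLOAT, bS, nOut);
        SDVariable Wx    = sd.placeHolder("Wx", DataType.FLOAT, nIn, 3 * nOut);
        SDVariable Wh    = sd.placeHolder("Wh", DataType.FLOAT, nOut, 3 * nOut);
        SDVariable b     = sd.placeHolder("b", DataType.FLOAT, 3 * nOut);
        SDVariable h     = sd.rnn().gru(x, hLast, Wx, Wh, b); // [time, bS, nOut]
    }
}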
* @@ -42,9 +84,9 @@ public class SDRNN extends SDOps { * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param GRUWeights Configuration Object */ - public SDVariable[] gru(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { - SDValidation.validateNumerical("gru", "x", x); - SDValidation.validateNumerical("gru", "hLast", hLast); + public SDVariable[] gruCell(SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { + SDValidation.validateNumerical("gruCell", "x", x); + SDValidation.validateNumerical("gruCell", "hLast", hLast); return new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables(); } @@ -56,9 +98,10 @@ public class SDRNN extends SDOps { * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param GRUWeights Configuration Object */ - public SDVariable[] gru(String[] names, SDVariable x, SDVariable hLast, GRUWeights GRUWeights) { - SDValidation.validateNumerical("gru", "x", x); - SDValidation.validateNumerical("gru", "hLast", hLast); + public SDVariable[] gruCell(String[] names, SDVariable x, SDVariable hLast, + GRUWeights GRUWeights) { + SDValidation.validateNumerical("gruCell", "x", x); + SDValidation.validateNumerical("gruCell", "hLast", hLast); SDVariable[] out = new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(sd,x, hLast, GRUWeights).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java index 079755055..1a6ced62f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java @@ -16,8 +16,6 @@ package org.nd4j.evaluation.custom; -import org.nd4j.shade.guava.collect.Lists; - import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 6af2d462a..0c6724e9a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -143,7 +143,11 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3dBp.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2dDerivative.class, + org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU.class, + org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUBp.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class, @@ -299,6 +303,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.shape.Linspace.class, org.nd4j.linalg.api.ops.impl.shape.MergeAvg.class, 
org.nd4j.linalg.api.ops.impl.shape.MergeMax.class, + org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex.class, org.nd4j.linalg.api.ops.impl.shape.MergeSum.class, org.nd4j.linalg.api.ops.impl.shape.MeshGrid.class, org.nd4j.linalg.api.ops.impl.shape.OneHot.class, @@ -424,6 +429,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.ParallelConcat.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Pow.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseBp.class, org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence.class, org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseV2.class, org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits.class, @@ -640,6 +646,9 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.RandomCrop.class, org.nd4j.linalg.api.ops.custom.Roll.class, org.nd4j.linalg.api.ops.custom.ToggleBits.class, + org.nd4j.linalg.api.ops.custom.Tri.class, + org.nd4j.linalg.api.ops.custom.Triu.class, + org.nd4j.linalg.api.ops.custom.TriuBp.class, org.nd4j.linalg.api.ops.custom.Igamma.class, org.nd4j.linalg.api.ops.custom.Igammac.class, org.nd4j.linalg.api.ops.custom.Digamma.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java index b789a2925..1d1bba123 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationMish.java @@ -21,10 +21,8 @@ import lombok.Getter; import lombok.val; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; import org.nd4j.linalg.api.ops.impl.transforms.strict.Mish; import org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java index 823b67bbe..23aa3533a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java @@ -297,9 +297,12 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace { } protected void init() { - // we want params validation here + // in case of MMAP we don't want any learning applied + if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP && workspaceConfiguration.getPolicyLearning() != LearningPolicy.NONE) + throw new IllegalArgumentException("Workspace backed by memory-mapped file can't have LearningPolicy defined"); - if (currentSize.get() > 0) { + // we don't want overallocation in case of MMAP + if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) { if (!isOver.get()) { if (workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE && workspaceConfiguration.getOverallocationLimit() > 0) 
{ @@ -310,7 +313,6 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace { if (workspaceConfiguration.getMaxSize() > 0 && currentSize.get() > workspaceConfiguration.getMaxSize()) currentSize.set(workspaceConfiguration.getMaxSize()); - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java index be054a273..1baea4bc9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java @@ -159,7 +159,7 @@ public abstract class BasicWorkspaceManager implements MemoryWorkspaceManager { if (workspace == null || workspace instanceof DummyWorkspace) return; - //workspace.destroyWorkspace(); + workspace.destroyWorkspace(true); backingMap.get().remove(workspace.getId()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 3fe90bdbb..67dd5701d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -85,7 +85,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { } public DynamicCustomOp(SameDiff sameDiff, SDVariable arg) { - this(sameDiff, new SDVariable[]{arg}); + this(sameDiff, wrapOrNull(arg)); } public DynamicCustomOp(SameDiff sameDiff, SDVariable[] args) { @@ -655,6 +655,10 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { outputArguments.clear(); } + protected static SDVariable[] wrapOrNull(SDVariable in){ + return in == null ? null : new SDVariable[]{in}; + } + protected static INDArray[] wrapOrNull(INDArray in){ return in == null ? 
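The Nd4jWorkspace.init() validation above rejects any learning policy on memory-mapped workspaces and skips overallocation for them, so an MMAP-backed workspace has to be sized up front. A configuration sketch (builder fields as referenced by the checks above; exact values illustrative):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MmapWorkspaceExample {
    public static void main(String[] args) {
        WorkspaceConfiguration mmap = WorkspaceConfiguration.builder()
                .initialSize(1024L * 1024L * 1024L)   // fixed size up front, since MMAP may not grow
                .policyLocation(LocationPolicy.MMAP)
                .policyLearning(LearningPolicy.NONE)  // anything else now throws in init()
                .build();

        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "MMAP_WS")) {
            INDArray arr = Nd4j.create(DataType.FLOAT, 10, 10); // allocated inside the mapped file
        }
    }
}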
null : new INDArray[]{in}; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java index 7c835006e..a107d8c10 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java @@ -19,7 +19,6 @@ import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java index b7c0e4092..2cd862515 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lstsq.java @@ -24,7 +24,6 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Tri.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Tri.java new file mode 100644 index 000000000..61187ce13 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Tri.java @@ -0,0 +1,76 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Tri extends DynamicCustomOp { + + private DataType dataType = DataType.FLOAT; + + public Tri(SameDiff sameDiff, int row, int column, int diag) { + super(sameDiff, new SDVariable[]{}); + addIArgument(row, column, diag); + } + + public Tri(SameDiff sameDiff, DataType dataType, int row, int column, int diag) { + super(sameDiff, new SDVariable[]{}); + addIArgument(row, column, diag); + addDArgument(dataType); + this.dataType = dataType; + } + + public Tri(int row, int column, int diag) { + super(new INDArray[]{}, null); + addIArgument(row, column, diag); + } + + public Tri(DataType dataType, int row, int column, int diag) { + super(new INDArray[]{}, null); + addIArgument(row, column, diag); + addDArgument(dataType); + this.dataType = dataType; + } + + @Override + public String opName() { + return "tri"; + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + return Collections.singletonList(this.dataType); + } +}
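Tri mirrors numpy's tri: it builds a row x column matrix with ones on and below the diag-th diagonal and zeros above it. A usage sketch (hypothetical values; exec-based invocation assumed, as with other DynamicCustomOps):

    INDArray out = Nd4j.exec(new Tri(DataType.FLOAT, 3, 4, 0))[0];
    // [[1, 0, 0, 0],
    //  [1, 1, 0, 0],
    //  [1, 1, 1, 0]]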
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Triu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Triu.java new file mode 100644 index 000000000..9ed9feb01 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Triu.java @@ -0,0 +1,71 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Triu extends DynamicCustomOp { + + private int diag = 0; + + public Triu(SameDiff sameDiff, SDVariable in, int diag) { + super(sameDiff, new SDVariable[]{in}); + addIArgument(diag); + this.diag = diag; + } + + public Triu(SameDiff sameDiff, SDVariable in) { + super(sameDiff, new SDVariable[]{in}); + } + + public Triu(INDArray input, int diag) { + super(new INDArray[]{input}, null); + addIArgument(diag); + this.diag = diag; + } + + @Override + public String opName() { + return "triu"; + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + return Collections.singletonList(arg(0).dataType()); + } + + @Override + public List<SDVariable> doDiff(List<SDVariable> f1) { + return new TriuBp(sameDiff, arg(0), f1.get(0), diag).outputs(); + } +}
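Triu keeps the elements on and above the diag-th diagonal and zeroes the rest (cf. numpy's triu); its gradient is routed through the dedicated TriuBp op defined next. A usage sketch (hypothetical values):

    INDArray in = Nd4j.rand(DataType.FLOAT, 3, 3);
    INDArray upper = Nd4j.exec(new Triu(in, 0))[0];  // strictly-lower triangle zeroed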
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriuBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriuBp.java new file mode 100644 index 000000000..2b7f78922 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/TriuBp.java @@ -0,0 +1,55 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class TriuBp extends DynamicCustomOp { + + public TriuBp(SameDiff sameDiff, SDVariable in, SDVariable grad, int diag) { + super(sameDiff, new SDVariable[]{in, grad}); + addIArgument(diag); + } + + public TriuBp(SameDiff sameDiff, SDVariable in, SDVariable grad) { + super(sameDiff, new SDVariable[]{in, grad}); + } + + @Override + public String opName() { + return "triu_bp"; + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + return Collections.singletonList(arg(0).dataType()); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java index 181321d4f..b2e0d1192 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java @@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseIndexAccumulation; -import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java index 760fca314..f20547c1d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java @@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseIndexAccumulation; -import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java index 7c6b5186c..c0ec62610 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling3D.java @@ -68,7 +68,7 @@ public class AvgPooling3D extends Pooling3D { return config.toProperties(); } - + @Override public String getPoolingPrefix() { return "avg"; }
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java index 5f5f6747a..8b151d717 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java @@ -82,7 +82,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp { @Override public Map<String, Object> propertiesForFunction() { - if(config == null && iArguments.size() > 0){ + if(config == null && !iArguments.isEmpty()){ //Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object config = Pooling2DConfig.builder() .kH(iArguments.get(0)) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java index c54b63aa7..1da1f36bf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling2D.java @@ -85,7 +85,7 @@ public class MaxPooling2D extends DynamicCustomOp { @Override public Map<String, Object> propertiesForFunction() { - if(config == null && iArguments.size() > 0){ + if(config == null && !iArguments.isEmpty()){ //Perhaps loaded from FlatBuffers - hence we have IArgs but not Config object config = Pooling2DConfig.builder() .kH(iArguments.get(0)) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java index 6c7aec888..e931d1583 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPooling3D.java @@ -75,7 +75,7 @@ public class MaxPooling3D extends Pooling3D { return config.toProperties(); } - + @Override public String getPoolingPrefix() { return "max"; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling3d.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling3d.java new file mode 100644 index 000000000..af3358eb7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling3d.java @@ -0,0 +1,99 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.convolution; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Upsampling3d operation + */ +@Slf4j +@Getter +@NoArgsConstructor +public class Upsampling3d extends DynamicCustomOp { + + protected boolean ncdhw; + protected int scaleH; + protected int scaleW; + protected int scaleD; + + public Upsampling3d(SameDiff sameDiff, SDVariable input, boolean ncdhw, int scaleD, int scaleH, int scaleW) { + super("upsampling3d", sameDiff, new SDVariable[]{input}); + this.ncdhw = ncdhw; + this.scaleD = scaleD; + this.scaleH = scaleH; + this.scaleW = scaleW; + + addIArgument(scaleD); + addIArgument(scaleH); + addIArgument(scaleW); + addIArgument(ncdhw ? 1 : 0); + } + + public Upsampling3d(INDArray input, boolean ncdhw, int scaleD, int scaleH, int scaleW) { + super(new INDArray[]{input}, null); + this.ncdhw = ncdhw; + this.scaleD = scaleD; + this.scaleH = scaleH; + this.scaleW = scaleW; + + addIArgument(scaleD); + addIArgument(scaleH); + addIArgument(scaleW); + addIArgument(ncdhw ? 1 : 0); + } + + @Override + public String opName() { + return "upsampling3d"; + } + + @Override + public List<SDVariable> doDiff(List<SDVariable> f1) { + return Arrays.asList(new Upsampling3dBp(sameDiff, arg(0), f1.get(0), this.ncdhw).outputVariables()); + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +}
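Upsampling3d repeats each spatial element of a rank-5 input by the given integer factors. A shape sketch (hypothetical sizes, NCDHW layout; the doubled output shape is the assumed behaviour of the underlying upsampling3d op):

    INDArray x = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4, 4);           // [minibatch, channels, d, h, w]
    INDArray up = Nd4j.exec(new Upsampling3d(x, true, 2, 2, 2))[0];  // expected shape [2, 3, 8, 8, 8]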
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling3dBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling3dBp.java new file mode 100644 index 000000000..ddd1147df --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling3dBp.java @@ -0,0 +1,54 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.convolution; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Upsampling3dBp extends DynamicCustomOp { + + public Upsampling3dBp(SameDiff sameDiff, SDVariable input, SDVariable grad0, boolean ncdhw) { + super("upsampling3d_bp", sameDiff, new SDVariable[]{input, grad0}); + addIArgument(ncdhw ? 1 : 0); + } + + @Override + public String opName() { + return "upsampling3d_bp"; + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input data types for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/BaseConvolutionConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/BaseConvolutionConfig.java index 1fcd49162..71ef8ffcf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/BaseConvolutionConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/BaseConvolutionConfig.java @@ -18,11 +18,14 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; import java.util.LinkedHashMap; import java.util.Map; + +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.linalg.exception.ND4JIllegalStateException; import java.lang.reflect.Field; +@Slf4j public abstract class BaseConvolutionConfig { public abstract Map<String, Object> toProperties(); @@ -61,7 +64,7 @@ public abstract class BaseConvolutionConfig { try { target.set(this, value); } catch (IllegalAccessException e) { - e.printStackTrace(); + log.error("",e); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRU.java new file mode 100644 index 000000000..0cc62833d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRU.java @@ -0,0 +1,67 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License.
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class GRU extends DynamicCustomOp { + + public GRU(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable hI, @NonNull SDVariable Wx, @NonNull SDVariable Wh, @NonNull SDVariable biases) { + super(null, sameDiff, new SDVariable[]{x, hI, Wx, Wh, biases}); + } + + public GRU(@NonNull INDArray x, @NonNull INDArray hI, @NonNull INDArray Wx, @NonNull INDArray Wh, @NonNull INDArray biases) { + super(new INDArray[]{x, hI, Wx, Wh, biases}, null); + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 5, "Expected 5 inputs to GRU: input, initial hidden state, input-to-hidden weights, hidden-to-hidden weights and biases; got %s", inputDataTypes); + for (int i = 0; i < inputDataTypes.size(); i++) { + Preconditions.checkState(inputDataTypes.get(i).isFPType(), "All input types must be a floating point type, got %s", inputDataTypes.get(i)); + } + return Collections.singletonList(inputDataTypes.get(1)); + } + + @Override + public List<SDVariable> doDiff(List<SDVariable> grads) { + return Arrays.asList(new GRUBp(sameDiff, arg(0), arg(1), arg(2), arg(3), + arg(4), grads.get(0)).outputVariables()); + } + + @Override + public String opName() { + return "gru"; + } +}
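GRU runs a full gated-recurrent-unit layer over a sequence given the initial hidden state, the input-to-hidden and hidden-to-hidden weight matrices, and the biases. A shape sketch (hypothetical sizes; the time-major [time, batch, nIn] layout and 3*nOut gate packing are assumptions about the underlying gru op):

    int time = 10, batch = 4, nIn = 8, nOut = 16;
    INDArray x  = Nd4j.rand(DataType.FLOAT, time, batch, nIn);
    INDArray hI = Nd4j.zeros(DataType.FLOAT, batch, nOut);
    INDArray Wx = Nd4j.rand(DataType.FLOAT, nIn, 3 * nOut);
    INDArray Wh = Nd4j.rand(DataType.FLOAT, nOut, 3 * nOut);
    INDArray b  = Nd4j.rand(DataType.FLOAT, 3 * nOut);
    INDArray h  = Nd4j.exec(new GRU(x, hI, Wx, Wh, b))[0];  // hidden state per time step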
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUBp.java new file mode 100644 index 000000000..b667fa811 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/GRUBp.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class GRUBp extends DynamicCustomOp { + + public GRUBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable hI, @NonNull SDVariable Wx, @NonNull SDVariable Wh, @NonNull SDVariable biases, @NonNull SDVariable dLdh) { + super(null, sameDiff, new SDVariable[]{x, hI, Wx, Wh, biases, dLdh}); + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 6, "Expected 6 input data types for %s, got %s", getClass(), inputDataTypes); + //One gradient output per forward-pass input: x, hI, Wx, Wh, biases + return new ArrayList<>(Collections.nCopies(5, inputDataTypes.get(1))); + } + + @Override + public String opName() { + return "gru_bp"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java index bebbd5f8f..b4567cb0d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java @@ -24,12 +24,10 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; import org.nd4j.shade.guava.primitives.Booleans; -import javax.xml.crypto.Data; import java.util.ArrayList; import java.util.Arrays; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java index 226150e8b..071b6adf3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java @@ -19,8 +19,6 @@ import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; import java.util.LinkedHashMap; import java.util.Map; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java index d8a2e6e9a..7e2b4fb5a 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java @@ -1,19 +1,12 @@ package org.nd4j.linalg.api.ops.impl.layers.recurrent.outputs; -import java.util.Arrays; -import java.util.List; - -import lombok.AccessLevel; import lombok.Getter; -import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.RnnDataFormat; -import org.nd4j.shade.guava.primitives.Booleans; /** * The outputs of a LSTM layer ({@link LSTMLayer}. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java index 98985df57..8f0113685 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/weights/LSTMLayerWeights.java @@ -22,7 +22,6 @@ import lombok.EqualsAndHashCode; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; import org.nd4j.linalg.util.ArrayUtil; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java index 4f6539eee..6146630a3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java @@ -19,8 +19,6 @@ package org.nd4j.linalg.api.ops.impl.loss; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java index 432910391..21ac3d25c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java @@ -21,8 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import 
org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp; - -import java.util.Arrays; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java index d021623d5..3f8613ac5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java @@ -22,7 +22,6 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp; -import java.util.Arrays; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java index acb74c04c..564e726b3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java @@ -23,7 +23,6 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp; -import java.util.Arrays; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java index 676eec5e0..b6bfc6748 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java @@ -22,7 +22,6 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp; -import java.util.Arrays; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java index c40d9e432..2bf1c66db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java @@ -22,7 +22,6 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp; -import java.util.Arrays; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/AbsoluteDifferenceLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/AbsoluteDifferenceLossBp.java index 54ba231bf..5e3802678 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/AbsoluteDifferenceLossBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/AbsoluteDifferenceLossBp.java @@ -19,7 +19,6 @@ package org.nd4j.linalg.api.ops.impl.loss.bp; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.loss.BaseLoss; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/HingeLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/HingeLossBp.java index cdc267f1f..8637c2fc7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/HingeLossBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/HingeLossBp.java @@ -19,7 +19,6 @@ package org.nd4j.linalg.api.ops.impl.loss.bp; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.loss.BaseLoss; /** * Hinge loss diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/LogPoissonLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/LogPoissonLossBp.java index 5bf5058fd..299aef89b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/LogPoissonLossBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/LogPoissonLossBp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.loss.bp; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -25,6 +26,7 @@ import org.nd4j.autodiff.samediff.SameDiff; * * @author Paul Dubs */ +@NoArgsConstructor public class LogPoissonLossBp extends BaseLossBp { private boolean full = false; @@ -39,9 +41,7 @@ public class LogPoissonLossBp extends BaseLossBp { addArgs(); } - public LogPoissonLossBp(){ } - - + @Override protected void addArgs(){ super.addArgs(); if(full){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/MeanSquaredErrorLossBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/MeanSquaredErrorLossBp.java index 5250f6fb1..9be5038de 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/MeanSquaredErrorLossBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/bp/MeanSquaredErrorLossBp.java @@ -19,7 +19,6 @@ package org.nd4j.linalg.api.ops.impl.loss.bp; import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.impl.loss.BaseLoss; /** * Mean squared error loss diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java index 5dfc23f8e..4ebb8701d 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java @@ -18,7 +18,6 @@ package org.nd4j.linalg.api.ops.impl.reduce.bool; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceBoolOp; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/VarianceBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/VarianceBp.java index 767392c97..3ab61e96d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/VarianceBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bp/VarianceBp.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.reduce.bp; +import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -26,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * * @author Alex Black */ - +@NoArgsConstructor public class VarianceBp extends BaseReductionBp { private boolean biasCorrected; @@ -43,8 +44,6 @@ public class VarianceBp extends BaseReductionBp { addTArgument(biasCorrected ? 1.0 : 0.0); } - public VarianceBp(){} - @Override public String opName() { return "reduce_variance_bp"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java index 6309ccf28..a4e27e42f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java @@ -22,7 +22,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; -import java.util.Collections; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java index 7376b0708..cb1de4765 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java @@ -19,7 +19,6 @@ package org.nd4j.linalg.api.ops.impl.reduce.longer; import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceLongOp; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java index 7ac8429c8..ff38c382e 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/GatherNd.java @@ -17,13 +17,11 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NoArgsConstructor; -import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.util.ArrayUtil; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index f83d61c0d..a08b30f74 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -18,7 +18,6 @@ package org.nd4j.linalg.api.ops.impl.shape; import lombok.NonNull; -import org.apache.commons.lang3.NotImplementedException; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -26,7 +25,6 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMaxIndex.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMaxIndex.java new file mode 100644 index 000000000..47f7b606a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMaxIndex.java @@ -0,0 +1,85 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class MergeMaxIndex extends DynamicCustomOp { + + private DataType dataType = DataType.INT32; + + public MergeMaxIndex(@NonNull SameDiff sameDiff, @NonNull SDVariable... inputs) { + super("mergemaxindex", sameDiff, inputs); + addIArgument(dataType.toInt()); + } + + public MergeMaxIndex(@NonNull INDArray... inputs) { + super("mergemaxindex", inputs, null); + Preconditions.checkArgument(areEqualShapes(inputs), "All inputs must have equal shapes"); + addIArgument(dataType.toInt()); + } + + public MergeMaxIndex(@NonNull SameDiff sd, @NonNull SDVariable[] x, @NonNull DataType dataType) { + super("mergemaxindex", sd, x); + this.dataType = dataType; + addIArgument(dataType.toInt()); + } + + public MergeMaxIndex(@NonNull INDArray[] x, @NonNull DataType dataType) { + super("mergemaxindex", x, null); + Preconditions.checkArgument(areEqualShapes(x), "All inputs must have equal shapes"); + this.dataType = dataType; + addIArgument(dataType.toInt()); + } + + protected static boolean areEqualShapes(INDArray... inputs) { + for (INDArray input : inputs) { + if (!inputs[0].equalShapes(input)) { + return false; + } + } + return true; + } + + @Override + public String opName() { + return "mergemaxindex"; + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) { + return Collections.singletonList(this.dataType); + } +} \ No newline at end of file
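MergeMaxIndex takes n equally-shaped arrays and returns, element-wise, the index (0..n-1) of the array holding the maximum value, in the configured integer type. A usage sketch (hypothetical values):

    INDArray a = Nd4j.createFromArray(1f, 5f, 2f);
    INDArray b = Nd4j.createFromArray(3f, 4f, 6f);
    INDArray idx = Nd4j.exec(new MergeMaxIndex(a, b))[0];  // expected [1, 0, 1] as INT32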
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index 64948880c..9f69f8e6a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -165,10 +165,7 @@ public class Variance extends BaseReduceOp { return false; INDArray z = oc != null ? oc.getOutputArray(0) : z(); - if (z != null && !z.isR()) - return false; - - return true; + return z == null || z.isR(); } @Override @@ -201,6 +198,7 @@ public class Variance extends BaseReduceOp { return Type.VARIANCE; } + @Override public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){ Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got input %s", getClass(), dataTypes); //Variance and stdev reduction: Always FP out, but if FP in is float/double/half then it's float/double/half out diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Histogram.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Histogram.java index 31937ded7..fdd48e133 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Histogram.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Histogram.java @@ -19,7 +19,6 @@ package org.nd4j.linalg.api.ops.impl.transforms; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.factory.Nd4j; /** * Histogram op wrapper diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java index 8cf81febf..a8be9803a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java @@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformBoolOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java index 9f8e9ea74..d5ff3cd6a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java @@ -17,12 +17,10 @@ package org.nd4j.linalg.api.ops.impl.transforms.bool; import lombok.NoArgsConstructor; -import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformBoolOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java index b3810ba15..d0836adc9 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java @@ -17,13 +17,11 @@ package org.nd4j.linalg.api.ops.impl.transforms.bool; import lombok.NonNull; -import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformBoolOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Condition; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java index 9612c4dea..dbf3bfdd4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java @@ -25,7 +25,6 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; -import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java index 92fb3b0eb..75683c72e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java @@ -19,7 +19,6 @@ import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java index 372f96c18..bd35c6813 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -23,6 +24,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; import java.util.Collections; @@ -30,8 +32,8 @@ import java.util.List;
public class Reverse extends DynamicCustomOp { - public Reverse(SameDiff sameDiff, SDVariable i_v, int... dimensions) { - super(null, sameDiff, new SDVariable[]{i_v}, false); + public Reverse(@NonNull SameDiff sameDiff, @NonNull SDVariable i_v, @NonNull int... dimensions) { + super(sameDiff, new SDVariable[]{i_v}); this.dimensions = dimensions; addIArgument(dimensions); } @@ -56,6 +58,7 @@ public class Reverse extends DynamicCustomOp { public Reverse(INDArray x, int... axis){ super(new INDArray[]{x}, new INDArray[0]); this.inPlace = false; + this.dimensions = axis; addIArgument(axis); } @@ -67,6 +70,7 @@ public class Reverse extends DynamicCustomOp { public Reverse(INDArray x, INDArray z, int... axis){ super(new INDArray[]{x}, new INDArray[] {z}); this.inPlace = false; + this.dimensions = axis; addIArgument(axis); } @@ -100,8 +104,7 @@ public class Reverse extends DynamicCustomOp { @Override public List<SDVariable> doDiff(List<SDVariable> f1) { - SDVariable ret = sameDiff.reverse(f1.get(0), dimensions); - return Collections.singletonList(ret); + return new ReverseBp(sameDiff, arg(0), f1.get(0), dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseBp.java new file mode 100644 index 000000000..c5a30f7e0 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseBp.java @@ -0,0 +1,47 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class ReverseBp extends DynamicCustomOp { + public ReverseBp(@NonNull SameDiff sameDiff, @NonNull SDVariable i_v, @NonNull SDVariable grad, @NonNull int... dimensions) {
+ super(sameDiff, new SDVariable[]{i_v, grad}); + addIArgument(dimensions); + } + + @Override + public String opName() { + return "reverse_bp"; + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + return Collections.singletonList(inputDataTypes.get(0)); + } +}
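With this change, Reverse no longer re-runs reverse on the incoming gradient in doDiff; it delegates to the dedicated reverse_bp op, which receives the original input plus the gradient. Forward usage is unchanged (hypothetical values):

    INDArray x = Nd4j.createFromArray(1f, 2f, 3f);
    INDArray r = Nd4j.exec(new Reverse(x, 0))[0];  // [3.0, 2.0, 1.0]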
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java @@ -18,18 +18,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; import lombok.NoArgsConstructor; import lombok.NonNull; -import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp; -import org.nd4j.linalg.util.ArrayUtil; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java index b0403ecff..8df28f40b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java @@ -24,7 +24,6 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; @NoArgsConstructor diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java index b4550cb4e..1a7e66dc2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java @@ -16,18 +16,13 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; -import java.util.Collections; - import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; - -import java.util.Arrays; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java index f5aec6b48..9db42dcfc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java @@ -21,10 +21,8 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; -import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; 
import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; -import java.util.Collections; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java index 44dca1ed1..88fae5f99 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java @@ -21,8 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java index a8d9f12ad..7273454bc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java @@ -21,8 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java index fc514a415..9174d59a6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASin.java @@ -21,11 +21,8 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; -import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java index 25f20e011..f394232b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java @@ -20,8 +20,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import 
org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java index a7a741759..5d311cb5e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java @@ -21,8 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATanh.java index b73884135..df094b295 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATanh.java @@ -20,8 +20,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java index 35ed040c2..2efe4020b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java @@ -21,8 +21,6 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java index 5144315da..dd942d898 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java @@ -21,8 +21,6 @@ import 
org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java index cc3a5e116..1c3c430d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java @@ -25,7 +25,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; -import java.util.Collections; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java index fc44f3c22..e0386cb37 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java @@ -16,18 +16,13 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; -import java.util.Collections; - import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; -import java.util.Arrays; import java.util.List; /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java index 6a118a062..33041438c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java @@ -20,12 +20,9 @@ import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; -import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java index 93e4e3c66..29a17aa29 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/BinomialDistribution.java @@ -47,6 +47,7 @@ public class BinomialDistribution extends BaseRandomOp { public BinomialDistribution(SameDiff sd, int trials, double probability, DataType dataType, long[] shape){ this(sd, trials, probability, shape); + super.dataType = dataType; } public BinomialDistribution(int trials, double probability, DataType dt, long[] shape){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java index 742e28113..e94e775eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/DropOut.java @@ -20,13 +20,9 @@ import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.BaseRandomOp; - -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; /** * DropOut implementation as Op diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java index 5cfecc6fe..1a4e97d22 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java @@ -329,7 +329,7 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet { dos.flush(); dos.close(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIterator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIterator.java index 424119e32..3673e3577 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIterator.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIterator.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.dataset; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -23,6 +24,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.util.ND4JFileUtils; +import javax.annotation.processing.SupportedAnnotationTypes; import java.io.*; import java.util.ArrayList; import java.util.List; @@ -32,6 +34,7 @@ import java.util.UUID; * Mini batch file datasetiterator * auto partitions a dataset in to mini batches */ +@Slf4j public class MiniBatchFileDataSetIterator implements DataSetIterator { private int batchSize; private List paths; @@ -75,7 +78,7 @@ public class MiniBatchFileDataSetIterator implements DataSetIterator { try { 
FileUtils.deleteDirectory(MiniBatchFileDataSetIterator.this.rootDir); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } } })); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/AbstractDataSetNormalizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/AbstractDataSetNormalizer.java index 1eacd626e..5dd6bd641 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/AbstractDataSetNormalizer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/AbstractDataSetNormalizer.java @@ -154,13 +154,13 @@ public abstract class AbstractDataSetNormalizer exten @Override public void transform(INDArray features, INDArray featuresMask) { - S featureStats = getFeatureStats(); + S featureStatsLocal = getFeatureStats(); - if(featureStats == null){ + if(featureStatsLocal == null){ throw new ND4JIllegalStateException("Features statistics were not yet calculated. Make sure to run fit() first."); } - strategy.preProcess(features, featuresMask, featureStats); } + strategy.preProcess(features, featuresMask, featureStatsLocal); } /** * Transform the labels. If {@link #isFitLabel()} == false, this is a no-op diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java index 14a8bf489..7be4d7d3b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java @@ -19,8 +19,6 @@ package org.nd4j.linalg.dataset.api.preprocessor; import org.nd4j.base.Preconditions; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; -import org.nd4j.linalg.dataset.api.MultiDataSet; -import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; /** * A simple Composite DataSetPreProcessor - allows you to apply multiple DataSetPreProcessors sequentially diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 2b1021742..6c371fe01 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.factory; +import lombok.extern.slf4j.Slf4j; import org.nd4j.linalg.factory.ops.*; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.guava.primitives.Longs; @@ -112,6 +113,7 @@ import java.util.logging.Logger; * * @author Adam Gibson */ +@Slf4j public class Nd4j { /** @@ -2460,7 +2462,7 @@ public class Nd4j { //noinspection ConstantConditions newArr.addi((format.parse(entries[0])).doubleValue()); } catch (ParseException e) { - e.printStackTrace(); + log.error("",e); } } else { Preconditions.checkState(entries.length == theShape[rank-1], "Invalid number of entries - format does not match expected shape." 
+ @@ -2470,7 +2472,7 @@ public class Nd4j { BigDecimal number = (BigDecimal) format.parse(entries[i]); subsetArr[i] = number.doubleValue(); } catch (ParseException e) { - e.printStackTrace(); + log.error("",e); } } INDArray subTensor = Nd4j.create(subsetArr, new long[]{subsetArr.length}, Nd4j.defaultFloatingPointType()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java index 7575c1238..c3b9ea558 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java @@ -196,7 +196,7 @@ public abstract class Nd4jBackend { try { Nd4jContext.getInstance().updateProperties(backend.getConfigurationResource().getInputStream()); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } if(logInit) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java index 1e3c89111..103e761be 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDCNN.java @@ -508,4 +508,19 @@ public class NDCNN { NDValidation.validateNumerical("upsampling2d", "input", input); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2d(input, scaleH, scaleW, nchw))[0]; } + + /** + * 3D Convolution layer operation - Upsampling 3d
+ * + * @param input Input in NCDHW or NDHWC format, as selected by ncdhw (NUMERIC type) + * @param ncdhw If true: input is in NCDHW (minibatch, channels, depth, height, width) format. False: NDHWC format + * @param scaleD Scale to upsample in depth dimension + * @param scaleH Scale to upsample in height dimension + * @param scaleW Scale to upsample in width dimension + * @return output Upsampled input (NUMERIC type) + */ + public INDArray upsampling3d(INDArray input, boolean ncdhw, int scaleD, int scaleH, int scaleW) { + NDValidation.validateNumerical("upsampling3d", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d(input, ncdhw, scaleD, scaleH, scaleW))[0]; + } }
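A minimal usage sketch of the new upsampling3d namespace method (assuming the NDCNN namespace is reachable via Nd4j.cnn(), as with the other generated namespaces; shapes are illustrative):

// Hedged sketch: 2x upsampling of an NCDHW input, [2, 3, 4, 4, 4] -> [2, 3, 8, 8, 8]
INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4, 4);
INDArray up = Nd4j.cnn().upsampling3d(in, true, 2, 2, 2);

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java index 536633cd2..03b9f8571 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java @@ -138,7 +138,7 @@ public class NDImage { /** * Resize images to size using the specified method.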
* - * @param input 4D image [NCHW] (NUMERIC type) + * @param input 4D image [NHWC] (NUMERIC type) * @param size new height and width (INT type) * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. * @param antialis Whether to use an anti-aliasing filter when downsampling an image @@ -161,7 +161,7 @@ public class NDImage { /** * Resize images to size using the specified method.
* - * @param input 4D image [NCHW] (NUMERIC type) + * @param input 4D image [NHWC] (NUMERIC type) * @param size new height and width (INT type) * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing.
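A hedged usage sketch of imageResize with the corrected NHWC layout (the Nd4j.image() accessor and the ImageResizeMethod enum location are assumptions based on the generated namespace conventions):

// Hedged sketch: bilinear resize of one NHWC image [1, 32, 32, 3] to 64x64
INDArray img = Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3);
INDArray size = Nd4j.createFromArray(64, 64);
INDArray resized = Nd4j.image().imageResize(img, size, ImageResizeMethod.ResizeBilinear);

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java index cb80c8092..0112515dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDLinalg.java @@ -20,6 +20,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; @@ -271,4 +272,51 @@ public class NDLinalg { NDValidation.validateNumerical("svd", "input", input); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Svd(input, fullUV, computeUV, 16))[0]; } + + /** + * An array with ones at and below the given diagonal and zeros elsewhere.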
+ * + * @param dataType Data type + * @param row Number of rows in the output array + * @param column Number of columns in the output array + * @param diagonal The sub-diagonal at and below which the array is filled with ones + * @return output (FLOATING_POINT type) + */ + public INDArray tri(DataType dataType, int row, int column, int diagonal) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Tri(dataType, row, column, diagonal))[0]; + } + + /** + * An array with ones at and below the given diagonal and zeros elsewhere.
+ * + * @param row Number of rows in the output array + * @param column Number of columns in the output array + * @return output (FLOATING_POINT type) + */ + public INDArray tri(int row, int column) { + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Tri(DataType.FLOAT, row, column, 0))[0]; + } + + /** + * Upper triangle of an array. Return a copy of an input tensor with the elements below the k-th diagonal zeroed.
+ * + * @param input Input tensor (NUMERIC type) + * @param diag The diagonal below which elements are zeroed; diag = 0 is the main diagonal + * @return output (FLOATING_POINT type) + */ + public INDArray triu(INDArray input, int diag) { + NDValidation.validateNumerical("triu", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Triu(input, diag))[0]; + } + + /** + * Upper triangle of an array. Return a copy of an input tensor with the elements below the k-th diagonal zeroed.
+ * + * @param input Input tensor (NUMERIC type) + * @return output (FLOATING_POINT type) + */ + public INDArray triu(INDArray input) { + NDValidation.validateNumerical("triu", "input", input); + return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.Triu(input, 0))[0]; + } }
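A short usage sketch of the new tri/triu helpers, mirroring the testTriOp and testTriuOp cases later in this diff (the Nd4j.linalg() accessor is assumed):

// tri: 3x5 array with ones at and below the k=2 diagonal
INDArray lower = Nd4j.linalg().tri(DataType.INT32, 3, 5, 2);
// triu: zero the elements below the k=-1 diagonal of a 4x3 input
INDArray m = Nd4j.rand(DataType.DOUBLE, 4, 3);
INDArray upper = Nd4j.linalg().triu(m, -1);

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index 1deddfd0a..b00da5e04 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -50,16 +50,41 @@ public class NDMath { * Looks up ids in a list of embedding tensors.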
* * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public INDArray embeddingLookup(INDArray x, INDArray indices, PartitionMode PartitionMode) { NDValidation.validateNumerical("EmbeddingLookup", "x", x); - NDValidation.validateNumerical("EmbeddingLookup", "indices", indices); + NDValidation.validateInteger("EmbeddingLookup", "indices", indices); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(x, indices, PartitionMode))[0]; }
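A short usage sketch of embeddingLookup with the now strictly-integer indices (the Nd4j.math() accessor and the PartitionMode enum constants are assumptions based on the generated namespace conventions):

// Hedged sketch: gather rows 0, 3 and 7 of a 10x4 embedding table -> [3, 4]
INDArray table = Nd4j.rand(DataType.FLOAT, 10, 4);
INDArray ids = Nd4j.createFromArray(0, 3, 7); // integer indices, per this change
INDArray vectors = Nd4j.math().embeddingLookup(table, ids, PartitionMode.MOD);

+ /** + * Returns an array of the indices of the maximum elements across the input tensors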
+ * + * @param x Input tensors (NUMERIC type) + * @param dataType Data type of the output index array + * @return output Array of indices of the maximum elements across the input tensors (INT type) + */ + public INDArray mergeMaxIndex(INDArray[] x, DataType dataType) { + NDValidation.validateNumerical("MergeMaxIndex", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(x, dataType))[0]; + } + + /** + * Returns an array of the indices of the maximum elements across the input tensors
+ * + * @param x Input tensors (NUMERIC type) + * @return output Array of indices of the maximum elements across the input tensors (INT type) + */ + public INDArray mergeMaxIndex(INDArray... x) { + NDValidation.validateNumerical("MergeMaxIndex", "x", x); + Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(x, DataType.INT))[0]; + }
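The following sketch mirrors the testMergeMaxIndex case later in this diff (the Nd4j.math() accessor is assumed):

// Index of the input holding the per-position maximum: here [0, 1, 2]
INDArray a = Nd4j.createFromArray(new float[]{1, 0, 0});
INDArray b = Nd4j.createFromArray(new float[]{0, 1, 0});
INDArray c = Nd4j.createFromArray(new float[]{0, 0, 1});
INDArray idx = Nd4j.math().mergeMaxIndex(a, b, c);

+ /** + * Elementwise absolute value operation: out = abs(x)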
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java index 9bb7d9640..6dee1ef7e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDRNN.java @@ -34,6 +34,25 @@ public class NDRNN { public NDRNN() { } + /** + * The GRU operation. Gated Recurrent Unit - Cho et al. 2014.
+ * + * @param x input [time, bS, nIn] (NUMERIC type) + * @param hLast initial cell output (at time step = 0) [bS, nOut] (NUMERIC type) + * @param Wx input-to-hidden weights, [nIn, 3*nOut] (NUMERIC type) + * @param Wh hidden-to-hidden weights, [nOut, 3*nOut] (NUMERIC type) + * @param biases biases, [3*nOut] (NUMERIC type) + * @return h cell outputs [time, bS, nOut], i.e., the output for each time step (NUMERIC type) + */ + public INDArray gru(INDArray x, INDArray hLast, INDArray Wx, INDArray Wh, INDArray biases) { + NDValidation.validateNumerical("gru", "x", x); + NDValidation.validateNumerical("gru", "hLast", hLast); + NDValidation.validateNumerical("gru", "Wx", Wx); + NDValidation.validateNumerical("gru", "Wh", Wh); + NDValidation.validateNumerical("gru", "biases", biases); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU(x, hLast, Wx, Wh, biases))[0]; + }
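A minimal sketch of the new full-sequence gru method, with shapes following the javadoc above (the Nd4j.rnn() accessor is assumed; values are illustrative):

int time = 2, bS = 3, nIn = 4, nOut = 5;
INDArray x = Nd4j.rand(DataType.FLOAT, time, bS, nIn);   // [time, bS, nIn]
INDArray hLast = Nd4j.zeros(DataType.FLOAT, bS, nOut);   // initial state at t = 0
INDArray Wx = Nd4j.rand(DataType.FLOAT, nIn, 3 * nOut);
INDArray Wh = Nd4j.rand(DataType.FLOAT, nOut, 3 * nOut);
INDArray b = Nd4j.rand(DataType.FLOAT, 3 * nOut);
INDArray h = Nd4j.rnn().gru(x, hLast, Wx, Wh, b);        // -> [time, bS, nOut]

+ /** + * The GRU cell. Does a single time step operation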
* @@ -41,9 +60,9 @@ public class NDRNN { * @param hLast Output of the previous cell/time step, with shape [batchSize, numUnits] (NUMERIC type) * @param GRUWeights Configuration Object */ - public INDArray[] gru(INDArray x, INDArray hLast, GRUWeights GRUWeights) { - NDValidation.validateNumerical("gru", "x", x); - NDValidation.validateNumerical("gru", "hLast", hLast); + public INDArray[] gruCell(INDArray x, INDArray hLast, GRUWeights GRUWeights) { + NDValidation.validateNumerical("gruCell", "x", x); + NDValidation.validateNumerical("gruCell", "hLast", hLast); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell(x, hLast, GRUWeights)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java index f472fae5f..c6879a133 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java @@ -22,15 +22,9 @@ import lombok.Getter; import lombok.Setter; import org.nd4j.base.Preconditions; import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.shape.OneHot; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.BooleanIndexing; -import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.linalg.lossfunctions.ILossFunction; -import org.nd4j.linalg.lossfunctions.LossUtil; -import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.primitives.Pair; import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer; import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java index ad887789d..c63289e2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java @@ -793,8 +793,9 @@ public class Transforms { * @return */ public static INDArray lessThanOrEqual(INDArray first, INDArray ndArray, boolean dup) { - return Nd4j.getExecutioner().exec(new LessThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())))[0]; - + val op = dup ? new LessThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())) : + new LessThanOrEqual(first, ndArray); + return Nd4j.getExecutioner().exec(op)[0]; } @@ -805,7 +806,9 @@ public class Transforms { * @return */ public static INDArray greaterThanOrEqual(INDArray first, INDArray ndArray, boolean dup) { - return Nd4j.getExecutioner().exec(new GreaterThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())))[0]; + val op = dup ? 
new GreaterThanOrEqual(first, ndArray, Nd4j.createUninitialized(DataType.BOOL, first.shape(), first.ordering())) : + new GreaterThanOrEqual(first, ndArray); + return Nd4j.getExecutioner().exec(op)[0]; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java index ecc4cfe18..fc6f13922 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/binary/BinarySerde.java @@ -262,7 +262,7 @@ public class BinarySerde { try (WritableByteChannel channel = Channels.newChannel(outputStream)) { channel.write(buffer); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java index af7e33f9c..0b2f1fca1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java @@ -39,6 +39,7 @@ import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.nativeblas.NativeOpsHolder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import lombok.extern.slf4j.Slf4j; import java.util.HashMap; import java.util.Map; @@ -49,6 +50,7 @@ import java.util.concurrent.atomic.AtomicLong; /** * Created by raver on 08.06.2016. */ +@Slf4j public class ProtectedCudaConstantHandler implements ConstantHandler { private static ProtectedCudaConstantHandler ourInstance = new ProtectedCudaConstantHandler(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java index 048a6d4c5..e8d9b2d18 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java @@ -413,7 +413,7 @@ public class CudaWorkspace extends Nd4jWorkspace { @Override public String getUniqueId() { - return "Workspace_" + getId(); + return "Workspace_" + getId() + "_" + Nd4j.getDeallocatorService().nextValue(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/CublasPointer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/CublasPointer.java index fe5f51080..d7109f5cb 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/CublasPointer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/CublasPointer.java @@ -22,6 +22,7 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.jcublas.buffer.JCudaBuffer; import org.nd4j.linalg.jcublas.context.CudaContext; +import lombok.extern.slf4j.Slf4j; /** * Wraps the allocation @@ -29,6 +30,7 @@ import org.nd4j.linalg.jcublas.context.CudaContext; * @author bam4d * */ +@Slf4j public class CublasPointer implements AutoCloseable 
{ /** @@ -166,7 +168,7 @@ public class CublasPointer implements AutoCloseable { try { pointer.close(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java index a6aca24c2..cec6ca762 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaFloatDataBuffer.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.util.ArrayUtil; +import lombok.extern.slf4j.Slf4j; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; @@ -33,6 +34,7 @@ import java.nio.ByteBuffer; * * @author Adam Gibson */ +@Slf4j public class CudaFloatDataBuffer extends BaseCudaDataBuffer { /** * Meant for creating another view of a buffer @@ -169,7 +171,7 @@ public class CudaFloatDataBuffer extends BaseCudaDataBuffer { try { dos.writeFloat(data[i]); } catch (IOException e) { - e.printStackTrace(); + log.error("",e); } return bos.toByteArray(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java index b773b7964..05c09ae85 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java @@ -62,7 +62,7 @@ public class CpuWorkspace extends Nd4jWorkspace implements Deallocatable { public String getUniqueId() { - return "Workspace_" + getId(); + return "Workspace_" + getId() + "_" + Nd4j.getDeallocatorService().nextValue(); } @Override @@ -92,7 +92,6 @@ public class CpuWorkspace extends Nd4jWorkspace implements Deallocatable { if (currentSize.get() > 0) { isInit.set(true); - if (isDebug.get()) log.info("Allocating [{}] workspace of {} bytes...", id, currentSize.get()); @@ -139,6 +138,13 @@ public class CpuWorkspace extends Nd4jWorkspace implements Deallocatable { } } + protected long mappedFileSize() { + if (workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) + return 0; + + return tempFile.length(); + } + @Override protected void clearExternalAllocations() { if (isDebug.get()) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java index 40553448a..436244ff8 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java @@ -18,11 +18,15 @@ package org.nd4j.linalg.cpu.nativecpu.workspace; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import org.bytedeco.javacpp.LongPointer; +import 
org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.memory.Deallocator; import org.nd4j.linalg.api.memory.enums.LocationPolicy; import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.memory.pointers.PointersPair; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.nativeblas.NativeOpsHolder; import java.util.List; import java.util.Queue; @@ -37,12 +41,16 @@ public class CpuWorkspaceDeallocator implements Deallocator { private Queue pinnedPointers; private List externalPointers; private LocationPolicy location; + private Pair mmapInfo; public CpuWorkspaceDeallocator(@NonNull CpuWorkspace workspace) { this.pointersPair = workspace.workspace(); this.pinnedPointers = workspace.pinnedPointers(); this.externalPointers = workspace.externalPointers(); this.location = workspace.getWorkspaceConfiguration().getPolicyLocation(); + + if (workspace.mappedFileSize() > 0) + this.mmapInfo = Pair.makePair(workspace.mmap, workspace.mappedFileSize()); } @Override @@ -50,7 +58,7 @@ public class CpuWorkspaceDeallocator implements Deallocator { log.trace("Deallocating CPU workspace"); // purging workspace planes - if (pointersPair != null) { + if (pointersPair != null && (pointersPair.getDevicePointer() != null || pointersPair.getHostPointer() != null)) { if (pointersPair.getDevicePointer() != null) { Nd4j.getMemoryManager().release(pointersPair.getDevicePointer(), MemoryKind.DEVICE); } @@ -58,6 +66,8 @@ public class CpuWorkspaceDeallocator implements Deallocator { if (pointersPair.getHostPointer() != null) { if (location != LocationPolicy.MMAP) Nd4j.getMemoryManager().release(pointersPair.getHostPointer(), MemoryKind.HOST); + else + NativeOpsHolder.getInstance().getDeviceNativeOps().munmapFile(null, mmapInfo.getFirst(), mmapInfo.getSecond()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java index e88f195c0..adac58697 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,6 +17,8 @@ package org.nd4j.autodiff; +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.FilenameUtils; import org.junit.Ignore; import org.junit.Test; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -23,20 +26,68 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.converters.ImportClassMapping; import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.NoOp; +import org.nd4j.linalg.api.ops.compat.CompatSparseToDense; +import org.nd4j.linalg.api.ops.compat.CompatStringSplit; +import org.nd4j.linalg.api.ops.custom.*; +import org.nd4j.linalg.api.ops.impl.broadcast.*; +import org.nd4j.linalg.api.ops.impl.broadcast.bool.*; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.*; +import org.nd4j.linalg.api.ops.impl.grid.FreeGridOp; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; +import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DTF; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DTF; +import org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp; +import org.nd4j.linalg.api.ops.impl.meta.PostulateMetaOp; +import org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp; +import org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp; +import org.nd4j.linalg.api.ops.impl.nlp.CbowRound; +import org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound; +import org.nd4j.linalg.api.ops.impl.reduce.HashCode; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarSetValue; +import org.nd4j.linalg.api.ops.impl.shape.ApplyGradientDescent; +import org.nd4j.linalg.api.ops.impl.shape.Create; +import org.nd4j.linalg.api.ops.impl.shape.ParallelStack; +import org.nd4j.linalg.api.ops.impl.transforms.any.Assign; +import org.nd4j.linalg.api.ops.impl.transforms.custom.ParallelConcat; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp; +import org.nd4j.linalg.api.ops.impl.updaters.*; +import org.nd4j.linalg.api.ops.persistence.RestoreV2; +import org.nd4j.linalg.api.ops.persistence.SaveV2; +import org.nd4j.linalg.api.ops.util.PrintAffinity; +import org.nd4j.linalg.api.ops.util.PrintVariable; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.lossfunctions.ILossFunction; +import org.nd4j.resources.Resources; import org.reflections.Reflections; +import java.io.File; import java.lang.reflect.Modifier; +import java.nio.charset.StandardCharsets; import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; public class TestOpMapping extends BaseNd4jTest { + Set> subTypes; + public TestOpMapping(Nd4jBackend b){ super(b); + + Reflections reflections = new Reflections("org.nd4j"); + subTypes = 
reflections.getSubTypesOf(DifferentialFunction.class); } @Override @@ -46,14 +97,13 @@ public class TestOpMapping extends BaseNd4jTest { @Override public long getTimeoutMilliseconds() { - return 60000L; + return 90000L; } + + @Test public void testOpMappingCoverage() throws Exception { - Reflections reflections = new Reflections("org.nd4j"); - Set> subTypes = reflections.getSubTypesOf(DifferentialFunction.class); - Map opNameMapping = ImportClassMapping.getOpNameMapping(); Map tfOpNameMapping = ImportClassMapping.getTFOpMappingFunctions(); Map onnxOpNameMapping = ImportClassMapping.getOnnxOpMappingFunctions(); @@ -102,6 +152,167 @@ public class TestOpMapping extends BaseNd4jTest { } } + @Test + public void testOpsInNamespace() throws Exception { + //Ensure that every op is either in a namespace, OR it's explicitly marked as ignored (i.e., an op that we don't + // want to add to a namespace for some reason) + //Note that we ignore "*Bp", "*Gradient", "*Derivative" etc ops + + String path = FilenameUtils.concat(new File("").getAbsolutePath(), "../nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops"); + path = FilenameUtils.normalize(path); + System.out.println(path); + File dir = new File(path); + Collection c = FileUtils.listFiles(dir, new String[]{"java"}, true); + + String strPattern = " org.nd4j.linalg.api.ops(\\.(\\w)+)+"; + Pattern pattern = Pattern.compile(strPattern); + + + Set seenClasses = new HashSet<>(); + for(File f1 : c){ + List lines = FileUtils.readLines(f1, StandardCharsets.UTF_8); + for(String l : lines){ + Matcher matcher = pattern.matcher(l); + while(matcher.find()){ + int s = matcher.start(); + int e = matcher.end(); + + String str = l.substring(s+1,e); //+1 because pattern starts with space + seenClasses.add(str); + } + } + } + + int countNotSeen = 0; + int countSeen = 0; + List notSeen = new ArrayList<>(); + for(Class cl : subTypes){ + String s = cl.getName(); + + //Backprop/gradient ops should not be in namespaces + if(s.endsWith("Bp") || s.endsWith("BpOp") || s.endsWith("Gradient") || s.endsWith("Derivative") || s.endsWith("Grad")) + continue; + + if(Modifier.isAbstract(cl.getModifiers()) || Modifier.isInterface(cl.getModifiers())) //Skip interfaces and abstract methods + continue; + + if(excludedFromNamespaces.contains(cl)) //Explicitly excluded - don't want in namespaces + continue; + + if(!seenClasses.contains(s)){ +// System.out.println("NOT SEEN: " + s); + notSeen.add(s); + countNotSeen++; + } else { + countSeen++; + } + } + + Collections.sort(notSeen); + System.out.println(String.join("\n", notSeen)); + + System.out.println("Not seen ops: " + countNotSeen); + System.out.println("Seen ops: " + countSeen); + System.out.println("Namespace scan count ops: " + seenClasses.size()); + } + + //Set of classes that we explicitly don't want in a namespace for some reason + private static final Set> excludedFromNamespaces = new HashSet<>(); + static { + Set> s = excludedFromNamespaces; + + //Updaters - used via TrainingConfig, not namespaces + s.add(AdaDeltaUpdater.class); + s.add(AdaGradUpdater.class); + s.add(AdaMaxUpdater.class); + s.add(AdamUpdater.class); + s.add(AmsGradUpdater.class); + s.add(NadamUpdater.class); + s.add(NesterovsUpdater.class); + s.add(RmsPropUpdater.class); + s.add(SgdUpdater.class); + + //Legacy broadcast ops + s.add(BroadcastAddOp.class); + s.add(BroadcastAMax.class); + s.add(BroadcastAMin.class); + s.add(BroadcastCopyOp.class); + s.add(BroadcastDivOp.class); + s.add(BroadcastGradientArgs.class); + s.add(BroadcastMax.class); + 
s.add(BroadcastMin.class); + s.add(BroadcastMulOp.class); + s.add(BroadcastRDivOp.class); + s.add(BroadcastRSubOp.class); + s.add(BroadcastSubOp.class); + s.add(BroadcastTo.class); + s.add(BroadcastEqualTo.class); + s.add(BroadcastGreaterThan.class); + s.add(BroadcastGreaterThanOrEqual.class); + s.add(BroadcastLessThan.class); + s.add(BroadcastLessThanOrEqual.class); + s.add(BroadcastNotEqual.class); + + //Redundant operations + s.add(ArgMax.class); //IMax already in namespace + s.add(ArgMin.class); //IMin already in namespace + + //Various utility methods, used internally + s.add(DynamicCustomOp.class); + s.add(ExternalErrorsFunction.class); + s.add(GradientBackwardsMarker.class); + s.add(KnnMinDistance.class); + s.add(BinaryRelativeError.class); + s.add(SpTreeCell.class); + s.add(BarnesHutGains.class); + s.add(BinaryMinimalRelativeError.class); + s.add(SkipGramRound.class); + s.add(BarnesHutSymmetrize.class); + s.add(BarnesEdgeForces.class); + s.add(CbowRound.class); + + //For TF compatibility only + s.add(NoOp.class); + s.add(RestoreV2.class); + s.add(ParallelConcat.class); + s.add(ParallelStack.class); + s.add(DeConv2DTF.class); + s.add(DeConv3DTF.class); + s.add(CompatSparseToDense.class); + s.add(CompatStringSplit.class); + s.add(ApplyGradientDescent.class); + s.add(RealDivOp.class); + s.add(SaveV2.class); + + //Control ops, used internally as part of loops etc + s.add(Enter.class); + s.add(Exit.class); + s.add(NextIteration.class); + s.add(LoopCond.class); + s.add(Merge.class); + s.add(Switch.class); + + //MetaOps, grid ops etc not part of public API + s.add(InvertedPredicateMetaOp.class); + s.add(PostulateMetaOp.class); + s.add(PredicateMetaOp.class); + s.add(ReduceMetaOp.class); + s.add(FreeGridOp.class); + + //Others that don't really make sense as a namespace method + s.add(CopyOp.class); + s.add(org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set.class); + s.add(PowPairwise.class); //We have custom op Pow already used for this + s.add(Create.class); //Already have zeros, ones, etc for this + s.add(HashCode.class); + s.add(ScalarSetValue.class); + s.add(PrintVariable.class); + s.add(PrintAffinity.class); + s.add(Assign.class); + + + + } @Test @Ignore public void generateOpClassList() throws Exception{ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index c83a55d08..964c82c18 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -38,6 +38,7 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat; import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode; @@ -1701,6 +1702,35 @@ public class LayerOpValidation extends BaseOpValidation { assertNull(err); } + @Test + public void GRUTestCase() { + int bS = 5; + int nIn = 4; + int nOut = 6; + int time = 2; + + SameDiff sd = SameDiff.create(); + + SDVariable in = 
sd.var("in", Nd4j.randn(DataType.DOUBLE, time, bS, nIn).muli(10)); + SDVariable hLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, nOut)); + SDVariable Wx = sd.var("Wx", Nd4j.randn(DataType.DOUBLE, nIn, 3*nOut)); + SDVariable Wh = sd.var("Wh", Nd4j.randn(DataType.DOUBLE, nOut, 3*nOut)); + SDVariable biases = sd.var("bias", Nd4j.randn(DataType.DOUBLE, 3*nOut)); + + SDVariable out = new GRU(sd, in, hLast, Wx, Wh,biases).outputVariable(); + + long[] outShapes = new long[]{time,bS, nOut}; + assertArrayEquals(new long[]{time,bS, nOut}, out.eval().shape()); + + sd.setLossVariables(out.std(true)); + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true) + ); + + assertNull(err); + + } + diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java index c47d02b04..f3c79db65 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RnnOpValidation.java @@ -227,7 +227,7 @@ public class RnnOpValidation extends BaseOpValidation { .cBias(bc) .build(); - SDVariable[] v = sd.rnn().gru(x, hLast, weights); + SDVariable[] v = sd.rnn().gruCell(x, hLast, weights); List toExec = new ArrayList<>(); for(SDVariable sdv : v){ toExec.add(sdv.name()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index df801a95a..73aae94bd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -33,6 +33,8 @@ import org.nd4j.autodiff.validation.TestCase; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.custom.Tri; +import org.nd4j.linalg.api.ops.custom.Triu; import org.nd4j.linalg.api.ops.impl.shape.*; import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill; import org.nd4j.linalg.api.shape.LongShapeDescriptor; @@ -2525,4 +2527,49 @@ public class ShapeOpValidation extends BaseOpValidation { assertArrayEquals(exp, out.shape()); } } + + + @Test + public void testMergeMaxIndex() { + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable inputX = sd.var(Nd4j.createFromArray(new float[] {1, 0, 0})); + SDVariable inputY = sd.var(Nd4j.createFromArray(new float[] {0, 1, 0})); + SDVariable inputZ = sd.var(Nd4j.createFromArray(new float[] {0, 0, 1})); + SDVariable out = new MergeMaxIndex(sd, new SDVariable[]{inputX, inputY, inputZ},DataType.INT32).outputVariable(); + INDArray expected = Nd4j.createFromArray(0,1,2); + String err = OpValidation.validate(new TestCase(sd) + .expectedOutput("mergemaxindex", expected) + .gradientCheck(false)); + assertNull(err); + + } + + @Test + public void testTriOp() { + + SameDiff sd = SameDiff.create(); + SDVariable out = new Tri(sd, DataType.INT32, 3, 5, 2).outputVariable(); + INDArray expected = Nd4j.createFromArray(new int[][]{{1, 1, 1, 0, 0}, {1, 1, 1, 1, 0}, {1, 1, 1, 1, 1}}); + String err = OpValidation.validate(new TestCase(sd) + .expectedOutput("tri", expected) + .gradientCheck(false)); + assertNull(err); + } + + @Test + public void 
testTriuOp() { + + SameDiff sd = SameDiff.create(); + SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {7,8,9},{10,11,12}})); + SDVariable out = new Triu(sd, input,-1).outputVariable(); + out.markAsLoss(); + INDArray expected = Nd4j.createFromArray(new double[][]{{1,2,3}, {4,5,6}, {0,8,9},{0,0,12}}); + String err = OpValidation.validate(new TestCase(sd) + .expectedOutput("triu", expected) + .gradientCheck(true)); + assertNull(err); + + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 108ccb3e2..a85f73643 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -40,11 +40,13 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.image.ImageResize; import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace; import org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth; +import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling3d; import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod; import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication; import org.nd4j.linalg.api.ops.impl.shape.Cross; import org.nd4j.linalg.api.ops.impl.shape.MergeAvg; import org.nd4j.linalg.api.ops.impl.shape.MergeMax; +import org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex; import org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup; import org.nd4j.linalg.api.ops.impl.transforms.Pad; import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm; @@ -2126,7 +2128,7 @@ public class TransformOpValidation extends BaseOpValidation { }; - SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable(); + SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable().std(true); String err = OpValidation.validate(new TestCase(sd) .gradientCheck(false) @@ -2150,7 +2152,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable inputY = sd.var(Nd4j.rand(2, 3)); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, inputX, inputY).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, inputX, inputY).outputVariable().std(true); String err = OpValidation.validate(new TestCase(sd) .gradientCheck(true)); assertNull(err); @@ -2166,7 +2168,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable inputX = sd.var(Nd4j.rand(2, 3)); SDVariable inputY = sd.var(Nd4j.rand(2, 3)); SDVariable inputZ = sd.var(Nd4j.rand(2, 3)); - SDVariable out = new MergeAddOp(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable(); + SDVariable out = new MergeAddOp(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable().std(true); out.markAsLoss(); String err = OpValidation.validate(new TestCase(sd) .gradientCheck(true)); @@ -2183,7 +2185,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable inputX = sd.var(Nd4j.rand(2, 3)); SDVariable inputY = sd.var(Nd4j.rand(2, 3)); SDVariable inputZ = sd.var(Nd4j.rand(2, 3)); - SDVariable out = new MergeMax(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable(); + SDVariable out = new MergeMax(sd, inputX, inputY, 
inputZ).outputVariable().std(true); out.markAsLoss(); String err = OpValidation.validate(new TestCase(sd) .gradientCheck(true)); @@ -2201,7 +2203,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable inputX = sd.var(Nd4j.rand(2, 3)); SDVariable inputY = sd.var(Nd4j.rand(2, 3)); SDVariable inputZ = sd.var(Nd4j.rand(2, 3)); - SDVariable out = new MergeAvg(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable(); + SDVariable out = new MergeAvg(sd, inputX, inputY, inputZ).outputVariable().std(true); out.markAsLoss(); String err = OpValidation.validate(new TestCase(sd) .gradientCheck(true)); @@ -2210,6 +2212,44 @@ } + @Test + public void testReverseBp() { + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable input = sd.var(Nd4j.createFromArray(new double[][]{{2,7}, {3,5}, {4,5}})); + SDVariable out = new Reverse(sd, input, 0).outputVariable(); + SDVariable loss = out.std(true); + loss.markAsLoss(); + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true)); + assertNull(err); + } + + @Test + public void testUpsampling3dBp() { + + Nd4j.getRandom().setSeed(12345); + for (boolean dataformat : new boolean[]{true, false}) { + + SameDiff sd = SameDiff.create(); + + // NCDHW input (dataformat == true), NDHWC input (dataformat == false) + SDVariable input = dataformat ? sd.var(Nd4j.rand(DataType.DOUBLE, 2, 1, 5, 5, 5)) : sd.var(Nd4j.rand(DataType.DOUBLE, 2, 5, 5, 5, 1)); + int scaleD = 2; + int scaleH = 2; + int scaleW = 2; + SDVariable out = new Upsampling3d(sd, input, true, scaleD, scaleH, scaleW).outputVariable().std(true); + out.markAsLoss(); + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true)); + assertNull(err); + } + + + } + + } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java index 3ab940937..75dc77cbc 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTrainingTest.java @@ -16,6 +16,7 @@ package org.nd4j.autodiff.samediff; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.util.Collections; @@ -35,6 +36,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.dataset.MultiDataSet; +import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator; import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; @@ -341,6 +343,39 @@ public class SameDiffTrainingTest extends BaseNd4jTest { History history = sd.fit(new SingletonMultiDataSetIterator(mds), 1); } + @Test + public void testTrainingEvalVarNotReqForLoss(){ + //If a variable is not required for the loss - normally it won't be calculated + //But we want to make sure it IS calculated here - so we can perform evaluation on it + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable z = in.mmul(w); + SDVariable out = sd.nn.softmax("softmax", z); + SDVariable loss 
= sd.loss.logLoss("loss", label, out); + SDVariable notRequiredForLoss = sd.nn.softmax("notRequiredForLoss", z); + + sd.setTrainingConfig(TrainingConfig.builder() + .updater(new Adam(0.001)) + .dataSetFeatureMapping("in") + .dataSetLabelMapping("label") + .trainEvaluation("notRequiredForLoss", 0, new Evaluation()) + .build()); + +// sd.setListeners(new ScoreListener(1)); + + DataSet ds = new DataSet(Nd4j.rand(DataType.FLOAT, 3, 4), Nd4j.createFromArray(new float[][]{{1,0,0}, {0,1,0}, {0,0,1}})); + + History h = sd.fit() + .train(new SingletonDataSetIterator(ds), 4) + .exec(); + + List l = h.trainingEval(Evaluation.Metric.ACCURACY); + assertEquals(4, l.size()); + } + @Override public char ordering() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 04ebaa0d7..12658ede8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -1826,4 +1826,78 @@ public class CustomOpsTests extends BaseNd4jTest { } + @Test + public void testBatchNormBpNHWC(){ + //Nd4j.getEnvironment().allowHelpers(false); //Passes if helpers/MKLDNN is disabled + + INDArray in = Nd4j.rand(DataType.FLOAT, 2, 4, 4, 3); + INDArray eps = Nd4j.rand(DataType.FLOAT, in.shape()); + INDArray epsStrided = eps.permute(1,0,2,3).dup().permute(1,0,2,3); + INDArray mean = Nd4j.rand(DataType.FLOAT, 3); + INDArray var = Nd4j.rand(DataType.FLOAT, 3); + INDArray gamma = Nd4j.rand(DataType.FLOAT, 3); + INDArray beta = Nd4j.rand(DataType.FLOAT, 3); + + assertEquals(eps, epsStrided); + + INDArray out1eps = in.like(); + INDArray out1m = mean.like(); + INDArray out1v = var.like(); + + INDArray out2eps = in.like(); + INDArray out2m = mean.like(); + INDArray out2v = var.like(); + + DynamicCustomOp op1 = DynamicCustomOp.builder("batchnorm_bp") + .addInputs(in, mean, var, gamma, beta, eps) + .addOutputs(out1eps, out1m, out1v) + .addIntegerArguments(1, 1, 3) + .addFloatingPointArguments(1e-5) + .build(); + + DynamicCustomOp op2 = DynamicCustomOp.builder("batchnorm_bp") + .addInputs(in, mean, var, gamma, beta, epsStrided) + .addOutputs(out2eps, out2m, out2v) + .addIntegerArguments(1, 1, 3) + .addFloatingPointArguments(1e-5) + .build(); + + Nd4j.exec(op1); + Nd4j.exec(op2); + + assertEquals(out1eps, out2eps); //Fails here + assertEquals(out1m, out2m); + assertEquals(out1v, out2v); + } + + @Test + public void testSpaceToDepthBadStrides(){ + INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3, 6, 6); + INDArray inBadStrides = in.permute(1,0,2,3).dup().permute(1,0,2,3); + assertEquals(in, inBadStrides); + + System.out.println("in: " + in.shapeInfoToString()); + System.out.println("inBadStrides: " + inBadStrides.shapeInfoToString()); + + + INDArray out = Nd4j.create(DataType.FLOAT, 2, 12, 3, 3); + INDArray out2 = out.like(); + + + CustomOp op1 = DynamicCustomOp.builder("space_to_depth") + .addInputs(in) + .addIntegerArguments(2, 0) //nchw = 0, nhwc = 1 + .addOutputs(out) + .build(); + Nd4j.exec(op1); + + CustomOp op2 = DynamicCustomOp.builder("space_to_depth") + .addInputs(inBadStrides) + .addIntegerArguments(2, 0) //nchw = 0, nhwc = 1 + .addOutputs(out2) + .build(); + Nd4j.exec(op2); + + assertEquals(out, out2); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java index 22c6e3a52..8fff93cad 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/profiling/OperationProfilerTests.java @@ -476,7 +476,7 @@ public class OperationProfilerTests extends BaseNd4jTest { Nd4j.exec(op); //Should trigger NaN panic fail(); } catch (Exception e){ - e.printStackTrace(); + log.error("",e); assertTrue(e.getMessage(), e.getMessage().contains("Inf")); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index 0c43ff9ca..5b357d9a4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -930,6 +930,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { WorkspaceConfiguration mmap = WorkspaceConfiguration.builder() .initialSize(1000000) .policyLocation(LocationPolicy.MMAP) + .policyLearning(LearningPolicy.NONE) .build(); MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java index 8d7ebb040..aefbafe53 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java @@ -26,16 +26,14 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import org.nd4j.linalg.api.memory.enums.AllocationPolicy; -import org.nd4j.linalg.api.memory.enums.LearningPolicy; -import org.nd4j.linalg.api.memory.enums.ResetPolicy; -import org.nd4j.linalg.api.memory.enums.SpillPolicy; +import org.nd4j.linalg.api.memory.enums.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace; +import java.nio.file.Files; import java.util.ArrayList; import java.util.Arrays; @@ -335,6 +333,98 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { assertEquals(exp, res); } + + @Test + public void testMmapedWorkspaceLimits_1() throws Exception { + if (!Nd4j.getEnvironment().isCPU()) + return; + + val tmpFile = Files.createTempFile("some", "file"); + val mmap = WorkspaceConfiguration.builder() + .initialSize(200 * 1024L * 1024L) // 200 MB + .tempFilePath(tmpFile.toAbsolutePath().toString()) + .policyLocation(LocationPolicy.MMAP) + .policyLearning(LearningPolicy.NONE) + .build(); + + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { + int twoHundredMbsOfFloats = 52_428_800; // 200 MB / 4 bytes per float + val addMoreFloats = true; + if (addMoreFloats) { + twoHundredMbsOfFloats += 1_000; + } + + val x = Nd4j.rand(DataType.FLOAT, twoHundredMbsOfFloats); + } + } + + @Test + public void testMmapedWorkspace_Path_Limits_1() throws Exception { + if (!Nd4j.getEnvironment().isCPU()) + 
return; + + // getting very long file name + val builder = new StringBuilder("long_file_name_"); + for (int e = 0; e < 100; e++) + builder.append("9"); + + + val tmpFile = Files.createTempFile("some", builder.toString()); + val mmap = WorkspaceConfiguration.builder() + .initialSize(200 * 1024L * 1024L) // 200mbs + .tempFilePath(tmpFile.toAbsolutePath().toString()) + .policyLocation(LocationPolicy.MMAP) + .policyLearning(LearningPolicy.NONE) + .build(); + + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { + val x = Nd4j.rand(DataType.FLOAT, 1024); + } + } + + @Test + public void testDeleteMappedFile_1() throws Exception { + if (!Nd4j.getEnvironment().isCPU()) + return; + + val tmpFile = Files.createTempFile("some", "file"); + val mmap = WorkspaceConfiguration.builder() + .initialSize(200 * 1024L * 1024L) // 200mbs + .tempFilePath(tmpFile.toAbsolutePath().toString()) + .policyLocation(LocationPolicy.MMAP) + .policyLearning(LearningPolicy.NONE) + .build(); + + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { + val x = Nd4j.rand(DataType.FLOAT, 1024); + } + + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + + Files.delete(tmpFile); + } + + @Test(expected = IllegalArgumentException.class) + public void testDeleteMappedFile_2() throws Exception { + if (!Nd4j.getEnvironment().isCPU()) + throw new IllegalArgumentException("Don't try to run on CUDA"); + + val tmpFile = Files.createTempFile("some", "file"); + val mmap = WorkspaceConfiguration.builder() + .initialSize(200 * 1024L * 1024L) // 200mbs + .tempFilePath(tmpFile.toAbsolutePath().toString()) + .policyLocation(LocationPolicy.MMAP) + .build(); + + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { + val x = Nd4j.rand(DataType.FLOAT, 1024); + } + + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + + Files.delete(tmpFile); + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-common-tests/pom.xml b/nd4j/nd4j-common-tests/pom.xml index 9f1765e21..ab280ae0a 100644 --- a/nd4j/nd4j-common-tests/pom.xml +++ b/nd4j/nd4j-common-tests/pom.xml @@ -35,7 +35,6 @@ org.reflections reflections ${reflections.version} - com.google.code.findbugs @@ -43,6 +42,14 @@ + + + + org.springframework + spring-core + 5.0.2.RELEASE + + diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java index 40b331ad5..34dd0f832 100644 --- a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/BaseND4JTest.java @@ -55,7 +55,7 @@ public abstract class BaseND4JTest { * Override this method to set the default timeout for methods in the test class */ public long getTimeoutMilliseconds(){ - return 30000; + return 60_000; } /** diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java new file mode 100644 index 000000000..95c09c154 --- /dev/null +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java @@ -0,0 +1,122 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.common.tests; + +import org.nd4j.linalg.io.ClassPathResource; +import org.nd4j.resources.Resources; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.PathMatchingResourcePatternResolver; +import org.springframework.core.io.support.ResourcePatternResolver; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Utilities for dealing with class path resources such as test files in JARs + * + * @author Alex Black + */ +public class ResourceUtils { + + private ResourceUtils() { + } + + /** + * List all classpath resource files, optionally recursively, inside the specified path/directory + * The path argument should be a directory. + * Returns the path of the resources, normalized by {@link Resources#normalize(String)} + * + * @param path Path in which to list all files + * @param recursive If true: list all files in subdirectories also. If false: only include files in the specified + * directory, but not any files in subdirectories + * @param includeDirectories If true: include any subdirectories in the returned list of files. False: Only return + * files, not directories + * @param extensions Optional - may be null (or length 0). If null/length 0: files with any extension are returned + * If non-null: only files matching one of the specified extensions are included. + * Extensions can be specified with or without "." - i.e., "csv" and ".csv" are the same + * @return List of files (and optionally directories) in the specified path + */ + public static List<String> listClassPathFiles(String path, boolean recursive, boolean includeDirectories, String... extensions) { + try { + return listClassPathFilesHelper(path, recursive, includeDirectories, extensions); + } catch (IOException e) { + throw new RuntimeException("Error listing class path files", e); + } + }
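+ //Editor's illustrative sketch, not part of the original patch: one way the method above
+ //might be called. The classpath directory name "testdir" and the ".csv" filter are hypothetical.
+ private static void exampleListCsvFiles() {
+ //Recursively list all .csv resources under the classpath directory "testdir",
+ //excluding the subdirectories themselves from the results:
+ List<String> csv = listClassPathFiles("testdir", true, false, "csv");
+ for (String resourcePath : csv) {
+ System.out.println(resourcePath); //Paths are returned normalized, per the javadoc above
+ }
+ }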
+ + private static List<String> listClassPathFilesHelper(String path, boolean recursive, boolean includeDirectories, String... extensions) throws IOException { + ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(new ClassPathResource(path).getClassLoader()); + + StringBuilder sbPattern = new StringBuilder("classpath*:" + path); + if (recursive) { + sbPattern.append("/**/*"); + } else { + sbPattern.append("/*"); + } + + //Normalize extensions so they are all like ".csv" etc - with leading "." + String[] normExt = null; + if (extensions != null && extensions.length > 0) { + normExt = new String[extensions.length]; + for (int i = 0; i < extensions.length; i++) { + if (!extensions[i].startsWith(".")) { + normExt[i] = "." + extensions[i]; + } else { + normExt[i] = extensions[i]; + } + } + } + + String pattern = sbPattern.toString(); + Resource[] resources = resolver.getResources(pattern); + List<String> out = new ArrayList<>(resources.length); + for (Resource r : resources) { + String origPath = r.getURL().toString(); + int idx = origPath.indexOf(path); + String relativePath = origPath.substring(idx); + String resourcePath = Resources.normalizePath(relativePath); + + if (resourcePath.endsWith("/")) { + if (includeDirectories) { + out.add(resourcePath); + } + continue; //Never apply the extension filter to directories - also avoids adding a directory twice + } + + if (normExt != null) { + //Check if it matches any of the specified extensions + boolean matches = false; + for (String ext : normExt) { + if (resourcePath.endsWith(ext)) { + matches = true; + break; + } + } + if (matches) { + out.add(resourcePath); + } + } else { + //Include all files + out.add(resourcePath); + } + } + return out; + } +} diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java b/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java index 14401d691..16c78923f 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JSystemProperties.java @@ -16,6 +16,9 @@ package org.nd4j.config; +import java.io.File; +import java.net.URL; + public class ND4JSystemProperties { /** @@ -125,6 +128,22 @@ public class ND4JSystemProperties { */ public static final String RESOURCES_CACHE_DIR = "org.nd4j.test.resources.cache.dir"; + /** + * Applicability: nd4j-common {@link org.nd4j.resources.Resources} class (and hence {@link org.nd4j.resources.strumpf.StrumpfResolver})
+ * Description: When resolving resources from a Strumpf resource file (Example: {@code Resources.asFile("myFile.txt")}), + * what should be the connection timeout, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)}
+ * Default: {@link org.nd4j.resources.strumpf.ResourceFile#DEFAULT_CONNECTION_TIMEOUT} + */ + public static final String RESOURCES_CONNECTION_TIMEOUT = "org.nd4j.resources.download.connectiontimeout"; + + /** + * Applicability: nd4j-common {@link org.nd4j.resources.Resources} class (and hence {@link org.nd4j.resources.strumpf.StrumpfResolver})
+ * Description: When resolving resources from a Strumpf resource file (Example: {@code Resources.asFile("myFile.txt")}), + * what should be the read timeout, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)}
+ * Default: {@link org.nd4j.resources.strumpf.ResourceFile#DEFAULT_READ_TIMEOUT} + */ + public static final String RESOURCES_READ_TIMEOUT = "org.nd4j.resources.download.readtimeout"; + /** * Applicability: nd4j-common {@link org.nd4j.resources.Resources} class (and hence {@link org.nd4j.resources.strumpf.StrumpfResolver})
* Description: When resolving resources, what local directories should be checked (in addition to the classpath) for files? diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/io/AbstractFileResolvingResource.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/io/AbstractFileResolvingResource.java index c0b5aa853..b6dcc6bf8 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/io/AbstractFileResolvingResource.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/io/AbstractFileResolvingResource.java @@ -28,6 +28,7 @@ import java.net.URLConnection; public abstract class AbstractFileResolvingResource extends AbstractResource { public AbstractFileResolvingResource() {} + @Override public File getFile() throws IOException { URL url = this.getURL(); return url.getProtocol().startsWith("vfs") @@ -35,6 +36,7 @@ public abstract class AbstractFileResolvingResource extends AbstractResource { : ResourceUtils.getFile(url, this.getDescription()); } + @Override protected File getFileForLastModifiedCheck() throws IOException { URL url = this.getURL(); if (ResourceUtils.isJarURL(url)) { @@ -53,6 +55,7 @@ public abstract class AbstractFileResolvingResource extends AbstractResource { : ResourceUtils.getFile(uri, this.getDescription()); } + @Override public boolean exists() { try { URL ex = this.getURL(); @@ -90,6 +93,7 @@ public abstract class AbstractFileResolvingResource extends AbstractResource { } } + @Override public boolean isReadable() { try { URL ex = this.getURL(); @@ -104,6 +108,7 @@ public abstract class AbstractFileResolvingResource extends AbstractResource { } } + @Override public long contentLength() throws IOException { URL url = this.getURL(); if (ResourceUtils.isFileURL(url)) { @@ -119,6 +124,7 @@ public abstract class AbstractFileResolvingResource extends AbstractResource { } } + @Override public long lastModified() throws IOException { URL url = this.getURL(); if (!ResourceUtils.isFileURL(url) && !ResourceUtils.isJarURL(url)) { diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java index 39c09e627..cead401b3 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java @@ -459,7 +459,7 @@ public class ArrayUtil { * @return the sum of this array */ public static int sum(List add) { - if (add.size() < 1) + if (add.isEmpty()) return 0; int ret = 0; for (int i = 0; i < add.size(); i++) @@ -498,7 +498,7 @@ public class ArrayUtil { * @return the product of this array */ public static int prod(List mult) { - if (mult.size() < 1) + if (mult.isEmpty()) return 0; int ret = 1; for (int i = 0; i < mult.size(); i++) @@ -546,7 +546,7 @@ public class ArrayUtil { * @return the product of this array */ public static long prodLong(List mult) { - if (mult.size() < 1) + if (mult.isEmpty()) return 0; long ret = 1; for (int i = 0; i < mult.size(); i++) @@ -1383,7 +1383,7 @@ public class ArrayUtil { long[] oldShapeB; - if (listB.size() == 0) { + if (listB.isEmpty()) { oldShapeB = new long[] {1}; } else { oldShapeB = Longs.toArray(listB); @@ -2965,7 +2965,7 @@ public class ArrayUtil { // now all even elements will be interleaved with odd elements for (int i = 0; i < result.length; i++) { - if (i % 2 == 0 && indexes.size() >= 1) { + if (i % 2 == 0 && !indexes.isEmpty()) { int idx = indexes.get(0); indexes.remove(0); result[i] = idx; @@ -3004,7 +3004,7 @@ public class ArrayUtil { // now all even elements will be 
interleaved with odd elements for (int i = 0; i < result.length; i++) { - if (i % 2 == 0 && indexes.size() >= 1) { + if (i % 2 == 0 && !indexes.isEmpty()) { int idx = indexes.get(0); indexes.remove(0); result[i] = idx; @@ -3132,7 +3132,7 @@ public class ArrayUtil { public static T getRandomElement(List list) { - if (list.size() < 1) + if (list.isEmpty()) return null; return list.get(RandomUtils.nextInt(0, list.size())); diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/MathUtils.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/MathUtils.java index c32b43669..51ef18253 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/MathUtils.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/MathUtils.java @@ -1326,11 +1326,11 @@ public class MathUtils { public static float randomFloatBetween(float begin, float end) { float rand = (float) Math.random(); - return begin + (rand * ((end - begin))); + return begin + (rand * (end - begin)); } public static double randomDoubleBetween(double begin, double end) { - return begin + (Math.random() * ((end - begin))); + return begin + (Math.random() * (end - begin)); } /** diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java index 05c44c29e..19352fc7c 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Downloader.java @@ -34,30 +34,49 @@ import java.net.URL; */ @Slf4j public class Downloader { + /** + * Default connection timeout in milliseconds when using {@link FileUtils#copyURLToFile(URL, File, int, int)} + */ + public static final int DEFAULT_CONNECTION_TIMEOUT = 60000; + /** + * Default read timeout in milliseconds when using {@link FileUtils#copyURLToFile(URL, File, int, int)} + */ + public static final int DEFAULT_READ_TIMEOUT = 60000; private Downloader(){ } /** - * Download the specified URL to the specified file, and verify that the target MD5 matches - * @param name Name (mainly for providing useful exceptions) - * @param url URL to download - * @param f Destination file - * @param targetMD5 Expected MD5 for file - * @param maxTries Maximum number of download attempts before failing and throwing an exception - * @throws IOException If an error occurs during downloading + * As per {@link #download(String, URL, File, String, int, int, int)} with the connection and read timeouts + * set to their default values - {@link #DEFAULT_CONNECTION_TIMEOUT} and {@link #DEFAULT_READ_TIMEOUT} respectively */ public static void download(String name, URL url, File f, String targetMD5, int maxTries) throws IOException { - download(name, url, f, targetMD5, maxTries, 0); + download(name, url, f, targetMD5, maxTries, DEFAULT_CONNECTION_TIMEOUT, DEFAULT_READ_TIMEOUT); } - private static void download(String name, URL url, File f, String targetMD5, int maxTries, int attempt) throws IOException { + /** + * Download the specified URL to the specified file, and verify that the target MD5 matches + * + * @param name Name (mainly for providing useful exceptions) + * @param url URL to download + * @param f Destination file + * @param targetMD5 Expected MD5 for file + * @param maxTries Maximum number of download attempts before failing and throwing an exception + * @param connectionTimeout connection timeout in milliseconds, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)} + * @param readTimeout read timeout in 
milliseconds, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)} + * @throws IOException If an error occurs during downloading + */ + public static void download(String name, URL url, File f, String targetMD5, int maxTries, int connectionTimeout, int readTimeout) throws IOException { + download(name, url, f, targetMD5, maxTries, 0, connectionTimeout, readTimeout); + } + + private static void download(String name, URL url, File f, String targetMD5, int maxTries, int attempt, int connectionTimeout, int readTimeout) throws IOException { boolean isCorrectFile = f.exists() && f.isFile() && checkMD5OfFile(targetMD5, f); if (attempt < maxTries) { if(!isCorrectFile) { - FileUtils.copyURLToFile(url, f); + FileUtils.copyURLToFile(url, f, connectionTimeout, readTimeout); if (!checkMD5OfFile(targetMD5, f)) { f.delete(); - download(name, url, f, targetMD5, maxTries, attempt + 1); + download(name, url, f, targetMD5, maxTries, attempt + 1, connectionTimeout, readTimeout); } } } else if (!isCorrectFile) { @@ -67,6 +86,14 @@ public class Downloader { } } + /** + * As per {@link #downloadAndExtract(String, URL, File, File, String, int, int, int)} with the connection and read timeouts + * * set to their default values - {@link #DEFAULT_CONNECTION_TIMEOUT} and {@link #DEFAULT_READ_TIMEOUT} respectively + */ + public static void downloadAndExtract(String name, URL url, File f, File extractToDir, String targetMD5, int maxTries) throws IOException { + downloadAndExtract(name, url, f, extractToDir, targetMD5, maxTries, DEFAULT_CONNECTION_TIMEOUT, DEFAULT_READ_TIMEOUT); + } + /** * Download the specified URL to the specified file, verify that the MD5 matches, and then extract it to the specified directory.
* Note that the file must be an archive, with the correct file extension: .zip, .jar, .tar.gz, .tgz or .gz @@ -77,20 +104,24 @@ public class Downloader { * @param extractToDir Destination directory to extract all files * @param targetMD5 Expected MD5 for file * @param maxTries Maximum number of download attempts before failing and throwing an exception + * @param connectionTimeout connection timeout in milliseconds, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)} + * @param readTimeout read timeout in milliseconds, as used by {@link org.apache.commons.io.FileUtils#copyURLToFile(URL, File, int, int)} * @throws IOException If an error occurs during downloading */ - public static void downloadAndExtract(String name, URL url, File f, File extractToDir, String targetMD5, int maxTries) throws IOException { - downloadAndExtract(0, maxTries, name, url, f, extractToDir, targetMD5); + public static void downloadAndExtract(String name, URL url, File f, File extractToDir, String targetMD5, int maxTries, + int connectionTimeout, int readTimeout) throws IOException { + downloadAndExtract(0, maxTries, name, url, f, extractToDir, targetMD5, connectionTimeout, readTimeout); } - private static void downloadAndExtract(int attempt, int maxTries, String name, URL url, File f, File extractToDir, String targetMD5) throws IOException { + private static void downloadAndExtract(int attempt, int maxTries, String name, URL url, File f, File extractToDir, + String targetMD5, int connectionTimeout, int readTimeout) throws IOException { boolean isCorrectFile = f.exists() && f.isFile() && checkMD5OfFile(targetMD5, f); if (attempt < maxTries) { if(!isCorrectFile) { - FileUtils.copyURLToFile(url, f); + FileUtils.copyURLToFile(url, f, connectionTimeout, readTimeout); if (!checkMD5OfFile(targetMD5, f)) { f.delete(); - downloadAndExtract(attempt + 1, maxTries, name, url, f, extractToDir, targetMD5); + downloadAndExtract(attempt + 1, maxTries, name, url, f, extractToDir, targetMD5, connectionTimeout, readTimeout); } } // try extracting @@ -99,7 +130,7 @@ public class Downloader { } catch (Throwable t){ log.warn("Error extracting {} files from file {} - retrying...", name, f.getAbsolutePath(), t); f.delete(); - downloadAndExtract(attempt + 1, maxTries, name, url, f, extractToDir, targetMD5); + downloadAndExtract(attempt + 1, maxTries, name, url, f, extractToDir, targetMD5, connectionTimeout, readTimeout); } } else if (!isCorrectFile) { //Too many attempts diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Resolver.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Resolver.java index 51e75dc5d..c25b7d123 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Resolver.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Resolver.java @@ -69,4 +69,11 @@ public interface Resolver { */ File localCacheRoot(); + /** + * Normalize the path that may be a resource reference. + * For example: "someDir/myFile.zip.resource_reference" --> "someDir/myFile.zip" + * Returns null if the file cannot be resolved. 
+ * If the file is not a reference, the original path is returned + */ + String normalizePath(String path); } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Resources.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Resources.java index b9e8253e3..14088f7c3 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Resources.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/Resources.java @@ -87,6 +87,15 @@ public class Resources { INSTANCE.copyDir(directoryPath, destinationDir); } + /** + * Normalize the path that may be a resource reference. + * For example: "someDir/myFile.zip.resource_reference" --> "someDir/myFile.zip" + * Returns null if the file cannot be resolved. + * If the file is not a reference, the original path is returned + */ + public static String normalizePath(String path){ + return INSTANCE.normalize(path); + } protected boolean resourceExists(String resourcePath) { for (Resolver r : resolvers) { @@ -128,4 +137,11 @@ public class Resources { } } + public String normalize(String path){ + for(Resolver r : resolvers){ + path = r.normalizePath(path); + } + return path; + } + } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java index f8fca14b5..6a69bd3b9 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/ResourceFile.java @@ -1,5 +1,6 @@ package org.nd4j.resources.strumpf; +import org.nd4j.config.ND4JSystemProperties; import org.nd4j.shade.guava.io.Files; import lombok.AllArgsConstructor; import lombok.Data; @@ -32,6 +33,14 @@ import java.util.Map; @JsonIgnoreProperties("filePath") @Slf4j public class ResourceFile { + /** + * Default value for resource downloading connection timeout - see {@link ND4JSystemProperties#RESOURCES_CONNECTION_TIMEOUT} + */ + public static final int DEFAULT_CONNECTION_TIMEOUT = 60000; //Timeout for connections to be established + /** + * Default value for resource downloading read timeout - see {@link ND4JSystemProperties#RESOURCES_READ_TIMEOUT} + */ + public static final int DEFAULT_READ_TIMEOUT = 60000; //Timeout for amount of time between connection established and data is available protected static final String PATH_KEY = "full_remote_path"; protected static final String HASH = "_hash"; protected static final String COMPRESSED_HASH = "_compressed_hash"; @@ -146,15 +155,20 @@ public class ResourceFile { String sha256PropertyCompressed = relativePath() + COMPRESSED_HASH; - //TODO NEXT LINE IN TEMPORARY UNTIL FIXED IN STRUMPF 0.3.2 -// sha256PropertyCompressed = sha256PropertyCompressed.replaceAll("/", "\\\\"); - String sha256Compressed = v1.get(sha256PropertyCompressed); Preconditions.checkState(sha256Compressed != null, "Expected JSON property %s was not found in resource reference file %s", sha256PropertyCompressed, filePath); String sha256Property = relativePath() + HASH; String sha256Uncompressed = v1.get(sha256Property); + String connTimeoutStr = System.getProperty(ND4JSystemProperties.RESOURCES_CONNECTION_TIMEOUT); + String readTimeoutStr = System.getProperty(ND4JSystemProperties.RESOURCES_READ_TIMEOUT); + boolean validCTimeout = connTimeoutStr != null && connTimeoutStr.matches("\\d+"); + boolean validRTimeout = readTimeoutStr != null && readTimeoutStr.matches("\\d+"); + + int connectTimeout = validCTimeout ? 
Integer.parseInt(connTimeoutStr) : DEFAULT_CONNECTION_TIMEOUT; + int readTimeout = validRTimeout ? Integer.parseInt(readTimeoutStr) : DEFAULT_READ_TIMEOUT; + try { boolean correctHash = false; for (int tryCount = 0; tryCount < MAX_DOWNLOAD_ATTEMPTS; tryCount++) { @@ -162,7 +176,7 @@ public class ResourceFile { if (tempFile.exists()) tempFile.delete(); log.info("Downloading remote resource {} to {}", remotePath, tempFile); - FileUtils.copyURLToFile(new URL(remotePath), tempFile); + FileUtils.copyURLToFile(new URL(remotePath), tempFile, connectTimeout, readTimeout); //Now: check if downloaded archive hash is OK String hash = sha256(tempFile); correctHash = sha256Compressed.equals(hash); diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/StrumpfResolver.java b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/StrumpfResolver.java index a9d2904ed..07ab95676 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/StrumpfResolver.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/resources/strumpf/StrumpfResolver.java @@ -258,6 +258,14 @@ public class StrumpfResolver implements Resolver { return cacheDir; } + @Override + public String normalizePath(@NonNull String path) { + if(path.endsWith(REF)){ + return path.substring(0, path.length()-REF.length()); + } + return path; + } + protected void assertExists(String resourcePath) { if (!exists(resourcePath)) { diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java b/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java index d51d9ca9b..9883ee5c3 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/util/ArchiveUtils.java @@ -44,16 +44,30 @@ public class ArchiveUtils { } /** - * Extracts files to the specified destination + * Extracts all files from the archive to the specified destination.
+ * Note: Logs the path of all extracted files by default. Use {@link #unzipFileTo(String, String, boolean)} if + * logging is not desired.
+ * Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats. + * Format is interpreted from the filename * - * @param file the file to extract to - * @param dest the destination directory - * @throws IOException + * @param file the file to extract the files from + * @param dest the destination directory. Will be created if it does not exist + * @throws IOException If an error occurs accessing the files or extracting */ public static void unzipFileTo(String file, String dest) throws IOException { unzipFileTo(file, dest, true); } + /** + * Extracts all files from the archive to the specified destination, optionally logging the extracted file path.
+ * Can handle .zip, .jar, .tar.gz, .tgz, .tar, and .gz formats. + * Format is interpreted from the filename + * + * @param file the file to extract the files from + * @param dest the destination directory. Will be created if it does not exist + * @param logFiles If true: log the path of every extracted file; if false do not log + * @throws IOException If an error occurs accessing the files or extracting + */ public static void unzipFileTo(String file, String dest, boolean logFiles) throws IOException { File target = new File(file); if (!target.exists()) diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml index fdaf7cc89..618b44c9e 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml @@ -40,7 +40,7 @@ com.mchange c3p0 - 0.9.5-pre5 + 0.9.5.4 org.nd4j diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java b/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java index e467035e8..4fa9cc794 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java @@ -16,6 +16,8 @@ package org.nd4j.jdbc.driverfinder; +import lombok.extern.slf4j.Slf4j; + import java.io.IOException; import java.io.InputStream; import java.sql.Driver; @@ -29,6 +31,7 @@ import java.util.Set; * * @author Adam Gibson */ +@Slf4j public class DriverFinder { public final static String ND4j_JDBC_PROPERTIES = "nd4j.jdbc.properties"; @@ -43,9 +46,9 @@ public class DriverFinder { try { driver = clazz.newInstance(); } catch (InstantiationException e) { - e.printStackTrace(); + log.error("",e); } catch (IllegalAccessException e) { - e.printStackTrace(); + log.error("",e); } } return driver; diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java index e4340c3a6..a546e8da9 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java @@ -16,6 +16,7 @@ package org.nd4j.jdbc.hsql; +import lombok.extern.slf4j.Slf4j; import org.hsqldb.jdbc.JDBCDataSource; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -32,6 +33,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; +@Slf4j public class HSqlLoaderTest extends BaseND4JTest { private static HsqlLoader hsqlLoader; private static DataSource dataSource; @@ -114,7 +116,7 @@ public class HSqlLoaderTest extends BaseND4JTest { return result.getInt("total"); } } catch (SQLException e) { - e.printStackTrace(); + log.error("",e); } return 0; } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/main/java/org/nd4j/parameterserver/client/ParameterServerClient.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/main/java/org/nd4j/parameterserver/client/ParameterServerClient.java index b41e0389c..d962bbfcf 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/main/java/org/nd4j/parameterserver/client/ParameterServerClient.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/main/java/org/nd4j/parameterserver/client/ParameterServerClient.java @@ -96,7 +96,7 @@ public class 
ParameterServerClient implements NDArrayCallback { .asJson().getBody().toString(), MasterStatus.class).getResponderN(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } return 0; } @@ -135,7 +135,7 @@ public class ParameterServerClient implements NDArrayCallback { .asJson().getBody().toString(), SubscriberState.class); return subscriberState.isReady(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } return false; } @@ -163,7 +163,7 @@ public class ParameterServerClient implements NDArrayCallback { .asJson().getBody().toString(), MasterStatus.class).started(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } return false; } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/BackgroundDaemonStarter.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/BackgroundDaemonStarter.java index aa32ba514..6fd33fe75 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/BackgroundDaemonStarter.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/BackgroundDaemonStarter.java @@ -123,7 +123,7 @@ public class BackgroundDaemonStarter { .redirectOutput(System.out).destroyOnExit().redirectError(System.err).execute() .getExitValue(); } catch (TimeoutException e) { - e.printStackTrace(); + log.error("",e); } } else { List args2 = new ArrayList<>( @@ -133,7 +133,7 @@ public class BackgroundDaemonStarter { new ProcessExecutor().command(args2).destroyOnExit().readOutput(true).redirectOutput(System.out) .redirectError(System.err).execute().getExitValue(); } catch (TimeoutException e) { - e.printStackTrace(); + log.error("",e); } } diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java index 6398cbab1..d94b1384a 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java @@ -63,7 +63,7 @@ public class RemoteParameterServerClientTests extends BaseND4JTest { masterStatus.set( BackgroundDaemonStarter.startMaster(parameterLength, mediaDriver.aeronDirectoryName())); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); @@ -73,7 +73,7 @@ public class RemoteParameterServerClientTests extends BaseND4JTest { try { slaveStatus.set(BackgroundDaemonStarter.startSlave(parameterLength, mediaDriver.aeronDirectoryName())); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); t2.start(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedAssignMessage.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedAssignMessage.java index 6a96fcc45..1a4942f18 100644 --- 
a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedAssignMessage.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedAssignMessage.java @@ -20,7 +20,6 @@ import lombok.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.parameterserver.distributed.messages.BaseVoidMessage; import org.nd4j.parameterserver.distributed.messages.DistributedMessage; -import org.nd4j.parameterserver.distributed.messages.RequestMessage; /** * Assign target row to specified value diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTracker.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTracker.java index fc19e271f..9b37d9bb9 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTracker.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTracker.java @@ -18,6 +18,7 @@ package org.nd4j.parameterserver.distributed.v2.chunks.impl; import lombok.Getter; import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.primitives.AtomicBoolean; @@ -34,6 +35,7 @@ import java.util.concurrent.ConcurrentHashMap; /** * File-based implementation of ChunksTracker */ +@Slf4j public class FileChunksTracker implements ChunksTracker { @Getter private final String originId; @@ -114,7 +116,7 @@ public class FileChunksTracker implements ChunksTracker print the usage info jcmdr.usage(); try { diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java index fbe44a101..3ada3fc08 100644 --- a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java +++ b/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java @@ -17,13 +17,6 @@ package org.nd4j.remote.clients.serde.impl; import lombok.*; -import org.nd4j.remote.clients.serde.JsonDeserializer; -import org.nd4j.remote.clients.serde.JsonSerializer; -import org.nd4j.shade.jackson.core.JsonProcessingException; -import org.nd4j.shade.jackson.databind.ObjectMapper; - -import java.io.IOException; - /** * This class provides JSON ser/de for Java double[] */ diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml index 5537216ca..004145101 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/pom.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -121,7 +121,7 @@ javax.activation activation - 1.1 + 1.1.1 diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArraySubscriber.java b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArraySubscriber.java index f5786b9e9..04cb3a880 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArraySubscriber.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArraySubscriber.java @@ -174,7 +174,7 @@ 
public class AeronNDArraySubscriber implements AutoCloseable { try { subscriber.launch(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); @@ -206,7 +206,7 @@ public class AeronNDArraySubscriber implements AutoCloseable { try { subscriber.launch(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponder.java b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponder.java index 482b269a5..ff9e826ba 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponder.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponder.java @@ -172,7 +172,7 @@ public class AeronNDArrayResponder implements AutoCloseable { try { subscriber.launch(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); @@ -210,7 +210,7 @@ public class AeronNDArrayResponder implements AutoCloseable { try { subscriber.launch(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/response/NDArrayResponseFragmentHandler.java b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/response/NDArrayResponseFragmentHandler.java index d259530a0..f42bfe9f4 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/response/NDArrayResponseFragmentHandler.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/response/NDArrayResponseFragmentHandler.java @@ -21,6 +21,7 @@ import io.aeron.logbuffer.FragmentHandler; import io.aeron.logbuffer.Header; import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.extern.slf4j.Slf4j; import org.agrona.DirectBuffer; import org.nd4j.aeron.ipc.AeronNDArrayPublisher; import org.nd4j.aeron.ipc.AeronUtil; @@ -42,6 +43,7 @@ import java.nio.ByteOrder; */ @AllArgsConstructor @Builder +@Slf4j public class NDArrayResponseFragmentHandler implements FragmentHandler { private NDArrayHolder holder; private Aeron.Context context; @@ -80,13 +82,13 @@ public class NDArrayResponseFragmentHandler implements FragmentHandler { try { publisher.publish(arrGet); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } try { publisher.close(); } catch (Exception e) { - + log.error("",e); } } } diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java index bd3ac9b9c..bb572a348 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java @@ -110,7 +110,7 @@ public class LargeNdArrayIpcTest extends BaseND4JTest { try { subscriber.launch(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java index d1a95ffde..ad8b8c16e 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java @@ -109,7 +109,7 @@ public class NdArrayIpcTest extends BaseND4JTest { try { subscriber.launch(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); 
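A minimal, self-contained illustration of the logging substitution applied throughout this patch: e.printStackTrace() is replaced by SLF4J logging, with Lombok's @Slf4j generating the static log field. The class and method names below are hypothetical, chosen only for the sketch:

import lombok.extern.slf4j.Slf4j;

@Slf4j
public class QuietLauncher {
    public void launch(Runnable subscriber) {
        try {
            subscriber.run();
        } catch (Exception e) {
            //Route the stack trace through the logging framework instead of stderr,
            //so output respects the application's logging configuration
            log.error("", e);
        }
    }
}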
@@ -133,7 +133,7 @@ public class NdArrayIpcTest extends BaseND4JTest { publisher.publish(arr); log.info("Sent array"); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); @@ -189,7 +189,7 @@ public class NdArrayIpcTest extends BaseND4JTest { try { subscriber.launch(); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } }); diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml deleted file mode 100644 index 453867c26..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/pom.xml +++ /dev/null @@ -1,95 +0,0 @@ - - - - - nd4j-camel-routes - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-kafka_2.11 - jar - - nd4j-kafka - https://deeplearning4j.org - - - - 2.11.12 - 2.11 - - - - - org.nd4j - nd4j-api - ${project.version} - - - org.scala-lang - scala-library - ${scala.version} - - - - io.netty - netty - ${netty.version} - - - net.jpountz.lz4 - lz4 - ${lz4.version} - - - org.xerial.snappy - snappy-java - ${snappy.version} - - - org.apache.camel - camel-kafka - ${camel.version} - - - org.slf4j - slf4j-log4j12 - - - log4j - log4j - - - - - - - org.nd4j - nd4j-common-tests - ${project.version} - test - - - - - - testresources - - - diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java deleted file mode 100644 index 3b40661bb..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/Nd4jKafkaRouteTest.java +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.kafka; - -import org.apache.camel.CamelContext; -import org.apache.camel.impl.DefaultCamelContext; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.nd4j.BaseND4JTest; -import org.nd4j.camel.kafka.KafkaConnectionInformation; -import org.nd4j.camel.kafka.Nd4jKafkaConsumer; -import org.nd4j.camel.kafka.Nd4jKafkaProducer; -import org.nd4j.camel.kafka.Nd4jKafkaRoute; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; - -/** - * Created by agibsonccc on 7/19/16. 
- */ -public class Nd4jKafkaRouteTest extends BaseND4JTest { - private EmbeddedKafkaCluster kafka; - private EmbeddedZookeeper zk; - private CamelContext camelContext; - public final static String TOPIC = "nd4jtest"; - public final static String GROUP_ID = "nd4j"; - private KafkaConnectionInformation connectionInformation; - - - - @Before - public void before() throws Exception { - zk = new EmbeddedZookeeper(TestUtils.getAvailablePort()); - zk.startup(); - kafka = new EmbeddedKafkaCluster(zk.getConnection()); - kafka.startup(); - kafka.createTopics(TOPIC); - camelContext = new DefaultCamelContext(); - camelContext.start(); - connectionInformation = KafkaConnectionInformation.builder().groupId(GROUP_ID).topicName(TOPIC) - .zookeeperHost("localhost").zookeeperPort(zk.getPort()).kafkaBrokerList(kafka.getBrokerList()) - .build(); - camelContext.addRoutes(Nd4jKafkaRoute.builder().kafkaConnectionInformation(connectionInformation).build()); - } - - @After - public void after() throws Exception { - if (kafka != null) - kafka.shutdown(); - if (zk != null) - zk.shutdown(); - if (camelContext != null) - camelContext.stop(); - } - - - @Test - public void testKafkaRoute() throws Exception { - Nd4jKafkaProducer kafkaProducer = Nd4jKafkaProducer.builder().camelContext(camelContext) - .connectionInformation(connectionInformation).build(); - kafkaProducer.publish(Nd4j.create(4)); - Nd4jKafkaConsumer consumer = Nd4jKafkaConsumer.builder().camelContext(camelContext) - .connectionInformation(connectionInformation).build(); - assertEquals(Nd4j.create(4), consumer.receive()); - } - - -} diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml deleted file mode 100644 index 60de01b6e..000000000 --- a/nd4j/nd4j-serde/nd4j-gson/pom.xml +++ /dev/null @@ -1,171 +0,0 @@ - - - - - - nd4j-serde - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-gson - 1.0.0-SNAPSHOT - - - org.nd4j - nd4j-api - ${project.version} - provided - - - - com.google.code.gson - gson - ${gson.version} - provided - - - - junit - junit - test - - - - - org.nd4j - nd4j-common-tests - ${project.version} - test - - - - - - testresources - - - - nd4j-tests-cpu - - false - - - - org.nd4j - nd4j-native - ${project.version} - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ - - - src/test/java - - *.java - **/*.java - **/Test*.java - **/*Test.java - **/*TestCase.java - - junit:junit - - org.nd4j.linalg.cpu.nativecpu.CpuBackend - - org.nd4j.linalg.cpu.nativecpu.CpuBackend - - - - -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g - - - - - - - - nd4j-tests-cuda - - false - - - - org.nd4j - nd4j-cuda-10.2 - ${project.version} - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - - org.apache.maven.surefire - surefire-junit47 - 2.19.1 - - - - - - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ - - - src/test/java - - *.java - **/*.java - **/Test*.java - **/*Test.java - **/*TestCase.java - - junit:junit - - org.nd4j.linalg.jcublas.JCublasBackend - - org.nd4j.linalg.jcublas.JCublasBackend - - - - -Ddtype=float -Xmx6g - - - - - - - - diff --git a/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java b/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java deleted file mode 100644 index a64641fe3..000000000 --- a/nd4j/nd4j-serde/nd4j-gson/src/test/java/org/nd4j/serde/gson/GsonDeserializationUtilsTest.java +++ /dev/null @@ -1,101 +0,0 @@ 
-/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.serde.gson; - -import org.junit.Test; -import org.nd4j.BaseND4JTest; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; - -public class GsonDeserializationUtilsTest extends BaseND4JTest { - @Test - public void deserializeRawJson_PassInInRank3Array_ExpectCorrectDeserialization() { - String serializedRawArray = "[[[1.00, 11.00, 3.00],\n" + "[13.00, 5.00, 15.00],\n" + "[7.00, 17.00, 9.00]]]"; - INDArray expectedArray = buildExpectedArray(1, 1, 3, 3); - - INDArray indArray = GsonDeserializationUtils.deserializeRawJson(serializedRawArray); - - assertEquals(expectedArray, indArray); - } - - @Test - public void deserializeRawJson_ArrayHasOnlyOneRowWithColumns_ExpectCorrectDeserialization() { - String serializedRawArray = "[1.00, 11.00, 3.00]"; - INDArray expectedArray = Nd4j.create(new double[] {1, 11, 3}).castTo(Nd4j.defaultFloatingPointType()); - - INDArray indArray = GsonDeserializationUtils.deserializeRawJson(serializedRawArray); - - assertEquals(expectedArray, indArray); - } - - @Test - public void deserializeRawJson_ArrayIsRankFive_ExpectCorrectDeserialization() { - String serializedRawArray = "[[[[[1.00, 11.00],\n" + " [3.00, 13.00]],\n" + " [[5.00, 15.00],\n" - + " [7.00, 17.00]]],\n" + " [[[9.00, 1.00],\n" + " [11.00, 3.00]],\n" - + " [[13.00, 5.00],\n" + " [15.00, 7.00]]],\n" + " [[[17.00, 9.00],\n" - + " [1.00, 11.00]],\n" + " [[3.00, 13.00],\n" + " [5.00, 15.00]]]],\n" - + " [[[[7.00, 17.00],\n" + " [9.00, 1.00]],\n" + " [[11.00, 3.00],\n" - + " [13.00, 5.00]]],\n" + " [[[15.00, 7.00],\n" + " [17.00, 9.00]],\n" - + " [[1.00, 11.00],\n" + " [3.00, 13.00]]],\n" + " [[[5.00, 15.00],\n" - + " [7.00, 17.00]],\n" + " [[9.00, 1.00],\n" + " [11.00, 3.00]]]],\n" - + " [[[[13.00, 5.00],\n" + " [15.00, 7.00]],\n" + " [[17.00, 9.00],\n" - + " [1.00, 11.00]]],\n" + " [[[3.00, 13.00],\n" + " [5.00, 15.00]],\n" - + " [[7.00, 17.00],\n" + " [9.00, 1.00]]],\n" + " [[[11.00, 3.00],\n" - + " [13.00, 5.00]],\n" + " [[15.00, 7.00],\n" + " [17.00, 9.00]]]]]"; - INDArray expectedArray = buildExpectedArray(8, 3, 3, 2, 2, 2); - - INDArray array = GsonDeserializationUtils.deserializeRawJson(serializedRawArray); - - assertEquals(expectedArray, array); - } - - - @Test - public void testSimpleVector() { - INDArray arr = Nd4j.linspace(1, 4, 4, Nd4j.defaultFloatingPointType()).reshape(4); - INDArray out = GsonDeserializationUtils.deserializeRawJson(arr.toString()); - assertEquals(arr, out); - } - - - - @Test - public void deserializeRawJson_HaveCommaInsideNumbers_ExpectCorrectDeserialization() { - String serializedRawArray = - "[[1.00, 1100.00, 3.00],\n" + "[13.00, 5.00, 15591.00],\n" + "[7000.00, 17.00, 
9.00]]"; - INDArray expectedArray = Nd4j.create(new double[] {1, 1100, 3, 13, 5, 15591, 7000, 17, 9}, new int[] {3, 3}).castTo(Nd4j.defaultFloatingPointType()); - - INDArray indArray = GsonDeserializationUtils.deserializeRawJson(serializedRawArray); - - assertEquals(expectedArray, indArray); - } - - private INDArray buildExpectedArray(int numberOfTripletRows, int... shape) { - INDArray expectedArray = Nd4j.create(3 * numberOfTripletRows, 3); - for (int i = 0; i < numberOfTripletRows; i++) { - int index = 3 * i; - expectedArray.putRow(index, Nd4j.create(new double[] {1, 11, 3})); - expectedArray.putRow(index + 1, Nd4j.create(new double[] {13, 5, 15})); - expectedArray.putRow(index + 2, Nd4j.create(new double[] {7, 17, 9})); - } - - return expectedArray.reshape(shape); - } -} diff --git a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java index 49861e3fe..f3cc72135 100644 --- a/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java +++ b/nd4j/nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java @@ -540,7 +540,7 @@ public class GraphRunner implements Closeable { org.tensorflow.framework.GraphDef graphDef1 = org.tensorflow.framework.GraphDef.parseFrom(graphToUse); initSessionAndStatusIfNeeded(graphDef1); } catch (org.nd4j.shade.protobuf.InvalidProtocolBufferException e) { - e.printStackTrace(); + log.error("",e); } } @@ -562,7 +562,7 @@ public class GraphRunner implements Closeable { org.tensorflow.framework.ConfigProto configProto = org.tensorflow.framework.ConfigProto.parseFrom(binaryString); return configProto; } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } return null; @@ -641,7 +641,7 @@ public class GraphRunner implements Closeable { try { return org.nd4j.shade.protobuf.util.JsonFormat.printer().print(sessionOptionsConfigProto); } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } return null; @@ -681,7 +681,7 @@ public class GraphRunner implements Closeable { } } catch (Exception e) { - e.printStackTrace(); + log.error("",e); } return builder1.build(); diff --git a/nd4j/nd4j-uberjar/pom.xml b/nd4j/nd4j-uberjar/pom.xml index bb0d45b5e..388e48fc5 100644 --- a/nd4j/nd4j-uberjar/pom.xml +++ b/nd4j/nd4j-uberjar/pom.xml @@ -185,16 +185,6 @@ nd4j-kryo_2.11 ${project.version} - - org.nd4j - nd4j-kafka_2.11 - ${project.version} - - - org.nd4j - nd4j-gson - ${project.version} - org.nd4j nd4j-common diff --git a/pom.xml b/pom.xml index 15af4658d..b34148656 100644 --- a/pom.xml +++ b/pom.xml @@ -271,7 +271,7 @@ 0.5.0 2.3.23 2.8.1 - 2.7.0 + 2.9.8 2.3 3.2 3.1 @@ -325,14 +325,14 @@ 3.2.2 4.1 - 2.4.3 + 2.4.5 2 2.0.29 1.7.21 4.12 1.2.3 2.10.1 - 2.10.1 + 2.10.3 1.24 2.8.7 1.18.12 @@ -353,7 +353,7 @@ 2.8.0 1.2.0-3f79e055 4.10.0 - 3.8.3 + 3.9.0 1.10.0 1.14.0 @@ -368,7 +368,7 @@ 3.7.0 3.3.1 3.0.1 - 1.0.0-beta8 + 1.0.0 2.2.2 ${maven-git-commit-plugin.version} diff --git a/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java b/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java index 66606de69..7ff58ef30 100644 --- a/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java +++ b/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java @@ -25,9 +25,13 @@ import org.bytedeco.javacpp.IntPointer; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.mdp.MDP; import 
org.deeplearning4j.rl4j.space.ArrayObservationSpace; +import org.deeplearning4j.rl4j.space.Box; import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.space.ObservationSpace; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; /** * @author saudet @@ -70,10 +74,14 @@ public class ALEMDP implements MDP { actions = new int[(int)a.limit()]; a.get(actions); + int height = (int)ale.getScreen().height(); + int width = (int)ale.getScreen().width(); + discreteSpace = new DiscreteSpace(actions.length); - int[] shape = {(int)ale.getScreen().height(), (int)ale.getScreen().width(), 3}; + int[] shape = {3, height, width}; observationSpace = new ArrayObservationSpace<>(shape); screenBuffer = new byte[shape[0] * shape[1] * shape[2]]; + } public void setupGame() { @@ -103,7 +111,7 @@ public class ALEMDP implements MDP { public GameScreen reset() { ale.reset_game(); ale.getScreenRGB(screenBuffer); - return new GameScreen(screenBuffer); + return new GameScreen(observationSpace.getShape(), screenBuffer); } @@ -115,7 +123,8 @@ public class ALEMDP implements MDP { double r = ale.act(actions[action]) * scaleFactor; log.info(ale.getEpisodeFrameNumber() + " " + r + " " + action + " "); ale.getScreenRGB(screenBuffer); - return new StepReply(new GameScreen(screenBuffer), r, ale.game_over(), null); + + return new StepReply(new GameScreen(observationSpace.getShape(), screenBuffer), r, ale.game_over(), null); } public ObservationSpace getObservationSpace() { @@ -140,17 +149,35 @@ public class ALEMDP implements MDP { } public static class GameScreen implements Encodable { - double[] array; - public GameScreen(byte[] screen) { - array = new double[screen.length]; - for (int i = 0; i < screen.length; i++) { - array[i] = (screen[i] & 0xFF) / 255.0; - } + final INDArray data; + public GameScreen(int[] shape, byte[] screen) { + + data = Nd4j.create(screen, new long[] {shape[1], shape[2], 3}, DataType.UINT8).permute(2,0,1); } + private GameScreen(INDArray toDup) { + data = toDup.dup(); + } + + @Override public double[] toArray() { - return array; + return data.data().asDouble(); + } + + @Override + public boolean isSkipped() { + return false; + } + + @Override + public INDArray getData() { + return data; + } + + @Override + public GameScreen dup() { + return new GameScreen(data); } } } diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java index ab054689a..e37750d72 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java @@ -19,15 +19,15 @@ package org.deeplearning4j.gym; import lombok.Value; /** - * @param <T> type of observation + * @param <OBSERVATION> type of observation * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/6/16. * * StepReply is the container for the data returned after each step(action). 
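 * For example, a hypothetical usage sketch (Box is one Encodable observation
 * type; the reward and done values shown are illustrative):
 *   StepReply<Box> reply = new StepReply<>(obs, 1.0, false, null);
 *   double r = reply.getReward();   // 1.0
 *   boolean done = reply.isDone();  // false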
*/ @Value -public class StepReply<T> { +public class StepReply<OBSERVATION> { - T observation; + OBSERVATION observation; double reward; boolean done; Object info; } diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java index 37b097dbf..e911a7acc 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java @@ -32,7 +32,7 @@ import org.deeplearning4j.rl4j.space.ObservationSpace; * in a "functionnal manner" if step return a mdp * */ -public interface MDP<O, A, AS extends ActionSpace<A>> { +public interface MDP<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> { ObservationSpace getObservationSpace(); diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java index e90601fda..3bc242fea 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java @@ -16,6 +16,9 @@ package org.deeplearning4j.rl4j.space; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/8/16. * @@ -25,13 +28,37 @@ package org.deeplearning4j.rl4j.space; */ public class Box implements Encodable { - private final double[] array; + private final INDArray data; - public Box(double[] arr) { - this.array = arr; + public Box(double... arr) { + this.data = Nd4j.create(arr); } + public Box(int[] shape, double... arr) { + this.data = Nd4j.create(arr).reshape(shape); + } + + private Box(INDArray toDup) { + data = toDup.dup(); + } + + @Override public double[] toArray() { - return array; + return data.data().asDouble(); + } + + @Override + public boolean isSkipped() { + return false; + } + + @Override + public INDArray getData() { + return data; + } + + @Override + public Encodable dup() { + return new Box(data); } } diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java index 04b0c22af..bfec24f68 100644 --- a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java +++ b/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K. K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -16,17 +16,19 @@ package org.deeplearning4j.rl4j.space; -/** - * @author rubenfiszel (ruben.fiszel@epfl.ch) on 7/19/16. - * Encodable is an interface that ensure that the state is convertible to a double array - */ +import org.nd4j.linalg.api.ndarray.INDArray; + public interface Encodable { - /** - * $ - * encodes all the information of an Observation in an array double and can be used as input of a DQN directly - * - * @return the encoded informations - */ + @Deprecated double[] toArray(); + + boolean isSkipped(); + + /** + * Any image data should be in CHW format. 
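 * CHW means channels-first: an RGB frame of height H and width W is shaped
 * [3, H, W]. A sketch (the 210x160 ALE screen size is illustrative):
 *   INDArray chw = Nd4j.create(DataType.UINT8, 3, 210, 160);
 * Data arriving in HWC order can be rearranged with permute(2, 0, 1), as
 * GameScreen does above.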
+ */ + INDArray getData(); + + Encodable dup(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java new file mode 100644 index 000000000..1b4a2699d --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java @@ -0,0 +1,210 @@ +package org.deeplearning4j.rl4j.agent; + +import lombok.AccessLevel; +import lombok.Getter; +import lombok.NonNull; +import org.deeplearning4j.rl4j.agent.listener.AgentListener; +import org.deeplearning4j.rl4j.agent.listener.AgentListenerList; +import org.deeplearning4j.rl4j.environment.Environment; +import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.observation.Observation; +import org.deeplearning4j.rl4j.observation.transform.TransformProcess; +import org.deeplearning4j.rl4j.policy.IPolicy; +import org.nd4j.base.Preconditions; + +import java.util.Map; + +public class Agent { + @Getter + private final String id; + + @Getter + private final Environment environment; + + @Getter + private final IPolicy policy; + + private final TransformProcess transformProcess; + + protected final AgentListenerList listeners; + + private final Integer maxEpisodeSteps; + + @Getter(AccessLevel.PROTECTED) + private Observation observation; + + @Getter(AccessLevel.PROTECTED) + private ACTION lastAction; + + @Getter + private int episodeStepNumber; + + @Getter + private double reward; + + protected boolean canContinue; + + private Agent(Builder builder) { + this.environment = builder.environment; + this.transformProcess = builder.transformProcess; + this.policy = builder.policy; + this.maxEpisodeSteps = builder.maxEpisodeSteps; + this.id = builder.id; + + listeners = buildListenerList(); + } + + protected AgentListenerList buildListenerList() { + return new AgentListenerList(); + } + + public void addListener(AgentListener listener) { + listeners.add(listener); + } + + public void run() { + runEpisode(); + } + + protected void onBeforeEpisode() { + // Do Nothing + } + + protected void onAfterEpisode() { + // Do Nothing + } + + protected void runEpisode() { + reset(); + onBeforeEpisode(); + + canContinue = listeners.notifyBeforeEpisode(this); + + while (canContinue && !environment.isEpisodeFinished() && (maxEpisodeSteps == null || episodeStepNumber < maxEpisodeSteps)) { + performStep(); + } + + if(!canContinue) { + return; + } + + onAfterEpisode(); + } + + protected void reset() { + resetEnvironment(); + resetPolicy(); + reward = 0; + lastAction = getInitialAction(); + canContinue = true; + } + + protected void resetEnvironment() { + episodeStepNumber = 0; + Map channelsData = environment.reset(); + this.observation = transformProcess.transform(channelsData, episodeStepNumber, false); + } + + protected void resetPolicy() { + policy.reset(); + } + + protected ACTION getInitialAction() { + return environment.getSchema().getActionSchema().getNoOp(); + } + + protected void performStep() { + + onBeforeStep(); + + ACTION action = decideAction(observation); + + canContinue = listeners.notifyBeforeStep(this, observation, action); + if(!canContinue) { + return; + } + + StepResult stepResult = act(action); + handleStepResult(stepResult); + + onAfterStep(stepResult); + + canContinue = listeners.notifyAfterStep(this, stepResult); + if(!canContinue) { + return; + } + + incrementEpisodeStepNumber(); + } + + protected void incrementEpisodeStepNumber() { + ++episodeStepNumber; + } + + protected ACTION decideAction(Observation observation) 
{ + if (!observation.isSkipped()) { + lastAction = policy.nextAction(observation); + } + + return lastAction; + } + + protected StepResult act(ACTION action) { + return environment.step(action); + } + + protected void handleStepResult(StepResult stepResult) { + observation = convertChannelDataToObservation(stepResult, episodeStepNumber + 1); + reward +=computeReward(stepResult); + } + + protected Observation convertChannelDataToObservation(StepResult stepResult, int episodeStepNumberOfObs) { + return transformProcess.transform(stepResult.getChannelsData(), episodeStepNumberOfObs, stepResult.isTerminal()); + } + + protected double computeReward(StepResult stepResult) { + return stepResult.getReward(); + } + + protected void onAfterStep(StepResult stepResult) { + // Do Nothing + } + + protected void onBeforeStep() { + // Do Nothing + } + + public static Builder builder(@NonNull Environment environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy policy) { + return new Builder<>(environment, transformProcess, policy); + } + + public static class Builder { + private final Environment environment; + private final TransformProcess transformProcess; + private final IPolicy policy; + private Integer maxEpisodeSteps = null; // Default, no max + private String id; + + public Builder(@NonNull Environment environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy policy) { + this.environment = environment; + this.transformProcess = transformProcess; + this.policy = policy; + } + + public Builder maxEpisodeSteps(int maxEpisodeSteps) { + Preconditions.checkArgument(maxEpisodeSteps > 0, "maxEpisodeSteps must be greater than 0, got", maxEpisodeSteps); + this.maxEpisodeSteps = maxEpisodeSteps; + + return this; + } + + public Builder id(String id) { + this.id = id; + return this; + } + + public Agent build() { + return new Agent(this); + } + } +} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java new file mode 100644 index 000000000..898f89241 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java @@ -0,0 +1,23 @@ +package org.deeplearning4j.rl4j.agent.listener; + +import org.deeplearning4j.rl4j.agent.Agent; +import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.observation.Observation; + +public interface AgentListener { + enum ListenerResponse { + /** + * Tell the learning process to continue calling the listeners and the training. + */ + CONTINUE, + + /** + * Tell the learning process to stop calling the listeners and terminate the training. 
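 * A hypothetical listener body showing the contract (maxSteps is an assumed
 * field of the listener):
 *   return agent.getEpisodeStepNumber() >= maxSteps
 *           ? ListenerResponse.STOP
 *           : ListenerResponse.CONTINUE;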
+ */ + STOP, + } + + AgentListener.ListenerResponse onBeforeEpisode(Agent agent); + AgentListener.ListenerResponse onBeforeStep(Agent agent, Observation observation, ACTION action); + AgentListener.ListenerResponse onAfterStep(Agent agent, StepResult stepResult); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java new file mode 100644 index 000000000..e003934d4 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java @@ -0,0 +1,50 @@ +package org.deeplearning4j.rl4j.agent.listener; + +import org.deeplearning4j.rl4j.agent.Agent; +import org.deeplearning4j.rl4j.environment.StepResult; +import org.deeplearning4j.rl4j.observation.Observation; + +import java.util.ArrayList; +import java.util.List; + +public class AgentListenerList { + protected final List> listeners = new ArrayList<>(); + + /** + * Add a listener at the end of the list + * @param listener The listener to be added + */ + public void add(AgentListener listener) { + listeners.add(listener); + } + + public boolean notifyBeforeEpisode(Agent agent) { + for (AgentListener listener : listeners) { + if (listener.onBeforeEpisode(agent) == AgentListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } + + public boolean notifyBeforeStep(Agent agent, Observation observation, ACTION action) { + for (AgentListener listener : listeners) { + if (listener.onBeforeStep(agent, observation, action) == AgentListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } + + public boolean notifyAfterStep(Agent agent, StepResult stepResult) { + for (AgentListener listener : listeners) { + if (listener.onAfterStep(agent, stepResult) == AgentListener.ListenerResponse.STOP) { + return false; + } + } + + return true; + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java new file mode 100644 index 000000000..f6521e734 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/ActionSchema.java @@ -0,0 +1,9 @@ +package org.deeplearning4j.rl4j.environment; + +import lombok.Value; + +@Value +public class ActionSchema { + private ACTION noOp; + //FIXME ACTION randomAction(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java new file mode 100644 index 000000000..95ff7d2b6 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java @@ -0,0 +1,11 @@ +package org.deeplearning4j.rl4j.environment; + +import java.util.Map; + +public interface Environment { + Schema getSchema(); + Map reset(); + StepResult step(ACTION action); + boolean isEpisodeFinished(); + void close(); +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java new file mode 100644 index 000000000..5ddea24cd --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java @@ -0,0 +1,8 @@ +package org.deeplearning4j.rl4j.environment; + +import lombok.Value; + +@Value +public class Schema { + private ActionSchema actionSchema; +} diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java new file mode 100644 index 000000000..b64dd08f5 --- /dev/null +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java @@ -0,0 +1,12 @@ +package org.deeplearning4j.rl4j.environment; + +import lombok.Value; + +import java.util.Map; + +@Value +public class StepResult { + private Map channelsData; + private double reward; + private boolean terminal; +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java index 2e608db19..b42a7c503 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java @@ -24,16 +24,17 @@ import org.nd4j.linalg.factory.Nd4j; * @author Alexandre Boulanger */ public class INDArrayHelper { + /** - * MultiLayerNetwork and ComputationGraph expect the first dimension to be the number of examples in the INDArray. - * In the case of RL4J, it must be 1. This method will return a INDArray with the correct shape. + * MultiLayerNetwork and ComputationGraph expects input data to be in NCHW in the case of pixels and NS in case of other data types. * - * @param source A INDArray - * @return The source INDArray with the correct shape + * We must have either shape 2 (NK) or shape 4 (NCHW) */ public static INDArray forceCorrectShape(INDArray source) { + return source.shape()[0] == 1 && source.shape().length > 1 ? source : Nd4j.expandDims(source, 0); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java index 550c6eb70..f3516af50 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java @@ -46,7 +46,6 @@ public class HistoryProcessor implements IHistoryProcessor { @Getter final private Configuration conf; - final private OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat(); private CircularFifoQueue history; private VideoRecorder videoRecorder; @@ -63,8 +62,7 @@ public class HistoryProcessor implements IHistoryProcessor { public void startMonitor(String filename, int[] shape) { if(videoRecorder == null) { - videoRecorder = VideoRecorder.builder(shape[0], shape[1]) - .frameInputType(VideoRecorder.FrameInputTypes.Float) + videoRecorder = VideoRecorder.builder(shape[1], shape[2]) .build(); } @@ -89,14 +87,13 @@ public class HistoryProcessor implements IHistoryProcessor { return videoRecorder != null && videoRecorder.isRecording(); } - public void record(INDArray raw) { + public void record(INDArray pixelArray) { if(isMonitoring()) { // before accessing the raw pointer, we need to make sure that array is actual on the host side - Nd4j.getAffinityManager().ensureLocation(raw, AffinityManager.Location.HOST); + Nd4j.getAffinityManager().ensureLocation(pixelArray, AffinityManager.Location.HOST); - VideoRecorder.VideoFrame frame = videoRecorder.createFrame(raw.data().pointer()); try { - videoRecorder.record(frame); + videoRecorder.record(pixelArray); } catch (Exception e) { e.printStackTrace(); } diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java index a8a09bc0b..6bd74fd28 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java @@ -64,7 +64,7 @@ public interface IHistoryProcessor { @Builder.Default int skipFrame = 4; public int[] getShape() { - return new int[] {getHistoryLength(), getCroppingHeight(), getCroppingWidth()}; + return new int[] {getHistoryLength(), getRescaledHeight(), getRescaledWidth()}; } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java index 0d1f0ae20..0d1b5bea2 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java @@ -28,9 +28,9 @@ import org.deeplearning4j.rl4j.space.Encodable; * * A common interface that any training method should implement */ -public interface ILearning> { +public interface ILearning> { - IPolicy getPolicy(); + IPolicy getPolicy(); void train(); @@ -38,7 +38,7 @@ public interface ILearning> { ILearningConfiguration getConfiguration(); - MDP getMdp(); + MDP getMdp(); IHistoryProcessor getHistoryProcessor(); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java index ca9451ea2..ba88454a7 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java @@ -21,7 +21,6 @@ import lombok.Getter; import lombok.Setter; import lombok.Value; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.space.ActionSpace; import org.deeplearning4j.rl4j.space.Encodable; @@ -38,8 +37,8 @@ import org.nd4j.linalg.factory.Nd4j; * */ @Slf4j -public abstract class Learning, NN extends NeuralNet> - implements ILearning, NeuralNetFetchable { +public abstract class Learning, NN extends NeuralNet> + implements ILearning, NeuralNetFetchable { @Getter @Setter protected int stepCount = 0; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java index 864683d79..26d8d5e02 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java @@ -29,10 +29,10 @@ import org.deeplearning4j.rl4j.learning.listener.TrainingListener; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.ActionSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.util.IDataManager; import org.deeplearning4j.rl4j.util.LegacyMDPWrapper; import org.nd4j.linalg.factory.Nd4j; @@ -188,7 +188,7 @@ public abstract class AsyncThread 
getAsyncGlobal(); - protected abstract IAsyncLearningConfiguration getConf(); + protected abstract IAsyncLearningConfiguration getConfiguration(); - protected abstract IPolicy getPolicy(NN net); + protected abstract IPolicy getPolicy(NN net); protected abstract SubEpochReturn trainSubEpoch(Observation obs, int nstep); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java index fcce92a4a..c32be6906 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java @@ -24,29 +24,22 @@ import lombok.Setter; import org.deeplearning4j.gym.StepReply; import org.deeplearning4j.rl4j.experience.ExperienceHandler; import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; -import org.deeplearning4j.rl4j.experience.ExperienceHandler; -import org.deeplearning4j.rl4j.experience.StateActionExperienceHandler; -import org.deeplearning4j.rl4j.experience.StateActionPair; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.NeuralNet; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.Stack; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. *

* Async Learning specialized for the Discrete Domain */ -public abstract class AsyncThreadDiscrete - extends AsyncThread { +public abstract class AsyncThreadDiscrete + extends AsyncThread { @Getter private NN current; @@ -59,7 +52,7 @@ public abstract class AsyncThreadDiscrete asyncGlobal, - MDP mdp, + MDP mdp, TrainingListenerList listeners, int threadNumber, int deviceNum) { @@ -97,7 +90,7 @@ public abstract class AsyncThreadDiscrete policy = getPolicy(current); + IPolicy policy = getPolicy(current); Integer action = getMdp().getActionSpace().noOp(); @@ -112,7 +105,7 @@ public abstract class AsyncThreadDiscrete stepReply = getLegacyMDPWrapper().step(action); - accuReward += stepReply.getReward() * getConf().getRewardFactor(); + accuReward += stepReply.getReward() * getConfiguration().getRewardFactor(); if (!obs.isSkipped()) { experienceHandler.addExperience(obs, action, accuReward, stepReply.isDone()); @@ -126,7 +119,7 @@ public abstract class AsyncThreadDiscrete extends AsyncLearning { +public abstract class A3CDiscrete extends AsyncLearning { @Getter final public A3CLearningConfiguration configuration; @Getter - final protected MDP mdp; + final protected MDP mdp; final private IActorCritic iActorCritic; @Getter final private AsyncGlobal asyncGlobal; @Getter - final private ACPolicy policy; + final private ACPolicy policy; - public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) { + public A3CDiscrete(MDP mdp, IActorCritic iActorCritic, A3CLearningConfiguration conf) { this.iActorCritic = iActorCritic; this.mdp = mdp; this.configuration = conf; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java index 17c6b8da8..08fec8a94 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java @@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraph; import org.deeplearning4j.rl4j.network.ac.ActorCriticFactoryCompGraphStdConv; import org.deeplearning4j.rl4j.network.ac.IActorCritic; import org.deeplearning4j.rl4j.network.configuration.ActorCriticNetworkConfiguration; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; @@ -42,19 +42,19 @@ import org.deeplearning4j.rl4j.util.IDataManager; * first layers since they're essentially doing the same dimension * reduction task **/ -public class A3CDiscreteConv extends A3CDiscrete { +public class A3CDiscreteConv extends A3CDiscrete { final private HistoryProcessor.Configuration hpconf; @Deprecated - public A3CDiscreteConv(MDP mdp, IActorCritic actorCritic, + public A3CDiscreteConv(MDP mdp, IActorCritic actorCritic, HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, actorCritic, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } @Deprecated - public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, + public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { super(mdp, IActorCritic, conf.toLearningConfiguration()); @@ -62,7 
+62,7 @@ public class A3CDiscreteConv extends A3CDiscrete { setHistoryProcessor(hpconf); } - public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, + public A3CDiscreteConv(MDP mdp, IActorCritic IActorCritic, HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { super(mdp, IActorCritic, conf); this.hpconf = hpconf; @@ -70,35 +70,35 @@ public class A3CDiscreteConv extends A3CDiscrete { } @Deprecated - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } @Deprecated - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraph factory, HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { this(mdp, factory.buildActorCritic(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } @Deprecated - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager); } @Deprecated - public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, + public A3CDiscreteConv(MDP mdp, ActorCriticFactoryCompGraphStdConv.Configuration netConf, HistoryProcessor.Configuration hpconf, A3CConfiguration conf) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf.toNetworkConfiguration()), hpconf, conf); } - public A3CDiscreteConv(MDP mdp, ActorCriticNetworkConfiguration netConf, + public A3CDiscreteConv(MDP mdp, ActorCriticNetworkConfiguration netConf, HistoryProcessor.Configuration hpconf, A3CLearningConfiguration conf) { this(mdp, new ActorCriticFactoryCompGraphStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java index 74332bf3a..5fd68f571 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java @@ -21,8 +21,8 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.*; import org.deeplearning4j.rl4j.network.configuration.ActorCriticDenseNetworkConfiguration; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; @@ -34,74 +34,74 @@ import 
org.deeplearning4j.rl4j.util.IDataManager; * We use specifically the Separate version because * the model is too small to have enough benefit by sharing layers */ -public class A3CDiscreteDense extends A3CDiscrete { +public class A3CDiscreteDense extends A3CDiscrete { @Deprecated - public A3CDiscreteDense(MDP mdp, IActorCritic IActorCritic, A3CConfiguration conf, + public A3CDiscreteDense(MDP mdp, IActorCritic IActorCritic, A3CConfiguration conf, IDataManager dataManager) { this(mdp, IActorCritic, conf); addListener(new DataManagerTrainingListener(dataManager)); } @Deprecated - public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CConfiguration conf) { + public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CConfiguration conf) { super(mdp, actorCritic, conf.toLearningConfiguration()); } - public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) { + public A3CDiscreteDense(MDP mdp, IActorCritic actorCritic, A3CLearningConfiguration conf) { super(mdp, actorCritic, conf); } @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } - public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparate factory, A3CLearningConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } @Deprecated - public A3CDiscreteDense(MDP mdp, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf, IDataManager dataManager) { this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } @Deprecated - public A3CDiscreteDense(MDP mdp, + public A3CDiscreteDense(MDP mdp, ActorCriticFactorySeparateStdDense.Configuration netConf, A3CConfiguration conf) { this(mdp, new ActorCriticFactorySeparateStdDense(netConf.toNetworkConfiguration()), conf); } - public A3CDiscreteDense(MDP mdp, + public A3CDiscreteDense(MDP mdp, ActorCriticDenseNetworkConfiguration netConf, A3CLearningConfiguration conf) { this(mdp, new ActorCriticFactorySeparateStdDense(netConf), conf); } @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, A3CConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } @Deprecated - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, A3CConfiguration conf) { this(mdp, factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } - public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, + public A3CDiscreteDense(MDP mdp, ActorCriticFactoryCompGraph factory, A3CLearningConfiguration conf) { this(mdp, 
factory.buildActorCritic(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index adf68489e..123680a38 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -23,23 +23,23 @@ import org.deeplearning4j.rl4j.learning.configuration.A3CLearningConfiguration; import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.ac.IActorCritic; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.policy.ACPolicy; import org.deeplearning4j.rl4j.policy.Policy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.api.rng.Random; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/23/16. - * + *

* Local thread as described in the https://arxiv.org/abs/1602.01783 paper. */ -public class A3CThreadDiscrete extends AsyncThreadDiscrete { +public class A3CThreadDiscrete extends AsyncThreadDiscrete { @Getter - final protected A3CLearningConfiguration conf; + final protected A3CLearningConfiguration configuration; @Getter final protected IAsyncGlobal asyncGlobal; @Getter @@ -47,17 +47,17 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< final private Random rnd; - public A3CThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, + public A3CThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, A3CLearningConfiguration a3cc, int deviceNum, TrainingListenerList listeners, int threadNumber) { super(asyncGlobal, mdp, listeners, threadNumber, deviceNum); - this.conf = a3cc; + this.configuration = a3cc; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; - Long seed = conf.getSeed(); + Long seed = configuration.getSeed(); rnd = Nd4j.getRandom(); - if(seed != null) { + if (seed != null) { rnd.setSeed(seed + threadNumber); } @@ -65,13 +65,16 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< } @Override - protected Policy getPolicy(IActorCritic net) { + protected Policy getPolicy(IActorCritic net) { return new ACPolicy(net, rnd); } + /** + * calc the gradients based on the n-step rewards + */ @Override protected UpdateAlgorithm buildUpdateAlgorithm() { int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape(); - return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), conf.getGamma()); + return new AdvantageActorCriticUpdateAlgorithm(asyncGlobal.getTarget().isRecurrent(), shape, getMdp().getActionSpace().getSize(), configuration.getGamma()); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index a4c0b643b..8a302d2d9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -28,26 +28,26 @@ import org.deeplearning4j.rl4j.learning.async.AsyncThread; import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguration; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.IPolicy; import org.deeplearning4j.rl4j.space.DiscreteSpace; -import org.deeplearning4j.rl4j.space.Encodable; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16. 
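 * A sketch of the n-step bootstrapped target this family regresses toward,
 * assuming standard notation (gamma is the configured discount factor):
 *   R_t = r_t + gamma*r_{t+1} + ... + gamma^{n-1}*r_{t+n-1}
 *         + gamma^n * max_a Q(s_{t+n}, a)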
*/ -public abstract class AsyncNStepQLearningDiscrete - extends AsyncLearning { +public abstract class AsyncNStepQLearningDiscrete + extends AsyncLearning { @Getter final public AsyncQLearningConfiguration configuration; @Getter - final private MDP mdp; + final private MDP mdp; @Getter final private AsyncGlobal asyncGlobal; - public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) { + public AsyncNStepQLearningDiscrete(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) { this.mdp = mdp; this.configuration = conf; this.asyncGlobal = new AsyncGlobal<>(dqn, conf); @@ -62,8 +62,8 @@ public abstract class AsyncNStepQLearningDiscrete return asyncGlobal.getTarget(); } - public IPolicy getPolicy() { - return new DQNPolicy(getNeuralNet()); + public IPolicy getPolicy() { + return new DQNPolicy(getNeuralNet()); } @Data diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java index f92b704b6..3f12a60ad 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java @@ -25,8 +25,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration; import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv; import org.deeplearning4j.rl4j.network.dqn.IDQN; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; @@ -35,17 +35,17 @@ import org.deeplearning4j.rl4j.util.IDataManager; * Specialized constructors for the Conv (pixels input) case * Specialized conf + provide additional type safety */ -public class AsyncNStepQLearningDiscreteConv extends AsyncNStepQLearningDiscrete { +public class AsyncNStepQLearningDiscreteConv extends AsyncNStepQLearningDiscrete { final private HistoryProcessor.Configuration hpconf; @Deprecated - public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, dqn, hpconf, conf); addListener(new DataManagerTrainingListener(dataManager)); } - public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { super(mdp, dqn, conf); this.hpconf = hpconf; @@ -53,21 +53,21 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN } @Deprecated - public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, - HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { + public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager); } - public AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, + public 
AsyncNStepQLearningDiscreteConv(MDP mdp, DQNFactory factory, HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf); } @Deprecated - public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, - HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { + public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf, dataManager); } - public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, + public AsyncNStepQLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf, HistoryProcessor.Configuration hpconf, AsyncQLearningConfiguration conf) { this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java index b6216e849..a94eba7a4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java @@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio import org.deeplearning4j.rl4j.network.dqn.DQNFactory; import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; import org.deeplearning4j.rl4j.network.dqn.IDQN; -import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.space.Encodable; +import org.deeplearning4j.rl4j.space.DiscreteSpace; import org.deeplearning4j.rl4j.util.DataManagerTrainingListener; import org.deeplearning4j.rl4j.util.IDataManager; /** * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/7/16. 
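 * A hypothetical wiring sketch (the mdp, netConf and conf instances are
 * assumed to exist; Box is the observation type):
 *   AsyncNStepQLearningDiscreteDense<Box> learning =
 *           new AsyncNStepQLearningDiscreteDense<>(mdp, netConf, conf);
 *   learning.train();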
*/ -public class AsyncNStepQLearningDiscreteDense extends AsyncNStepQLearningDiscrete { +public class AsyncNStepQLearningDiscreteDense extends AsyncNStepQLearningDiscrete { @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf, IDataManager dataManager) { super(mdp, dqn, conf.toLearningConfiguration()); addListener(new DataManagerTrainingListener(dataManager)); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, AsyncNStepQLConfiguration conf) { super(mdp, dqn, conf.toLearningConfiguration()); } - public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, + public AsyncNStepQLearningDiscreteDense(MDP mdp, IDQN dqn, AsyncQLearningConfiguration conf) { super(mdp, dqn, conf); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, AsyncNStepQLConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } - public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactory factory, AsyncQLearningConfiguration conf) { this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf, IDataManager dataManager) { this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager); } @Deprecated - public AsyncNStepQLearningDiscreteDense(MDP mdp, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf, AsyncNStepQLConfiguration conf) { this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf); } - public AsyncNStepQLearningDiscreteDense(MDP mdp, + public AsyncNStepQLearningDiscreteDense(MDP mdp, DQNDenseNetworkConfiguration netConf, AsyncQLearningConfiguration conf) { this(mdp, new DQNFactoryStdDense(netConf), conf); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index 34a2c07a4..0b8535f53 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -25,21 +25,21 @@ import org.deeplearning4j.rl4j.learning.configuration.AsyncQLearningConfiguratio import org.deeplearning4j.rl4j.learning.listener.TrainingListenerList; import org.deeplearning4j.rl4j.mdp.MDP; import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.space.Encodable; import org.deeplearning4j.rl4j.policy.DQNPolicy; import org.deeplearning4j.rl4j.policy.EpsGreedy; 
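// Note on the EpsGreedy policy built in getPolicy() below: exploration is
// annealed from 1.0 toward configuration.getMinEpsilon() over
// configuration.getEpsilonNbStep() steps once configuration.getUpdateStart()
// warm-up steps have elapsed. A sketch of the assumed linear schedule:
//   epsilon = max(minEpsilon,
//                 1.0 - Math.max(0, step - updateStart) / (double) epsilonNbStep);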
 import org.deeplearning4j.rl4j.policy.Policy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;
 
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/5/16.
  */
-public class AsyncNStepQLearningThreadDiscrete extends AsyncThreadDiscrete {
+public class AsyncNStepQLearningThreadDiscrete extends AsyncThreadDiscrete {
 
     @Getter
-    final protected AsyncQLearningConfiguration conf;
+    final protected AsyncQLearningConfiguration configuration;
     @Getter
     final protected IAsyncGlobal asyncGlobal;
     @Getter
@@ -47,31 +47,31 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn
     final private Random rnd;
 
-    public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal,
-                    AsyncQLearningConfiguration conf,
+    public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal,
+                    AsyncQLearningConfiguration configuration,
                     TrainingListenerList listeners, int threadNumber, int deviceNum) {
         super(asyncGlobal, mdp, listeners, threadNumber, deviceNum);
-        this.conf = conf;
+        this.configuration = configuration;
         this.asyncGlobal = asyncGlobal;
         this.threadNumber = threadNumber;
         rnd = Nd4j.getRandom();
 
-        Long seed = conf.getSeed();
-        if (seed != null) {
+        Long seed = configuration.getSeed();
+        if(seed != null) {
             rnd.setSeed(seed + threadNumber);
         }
         setUpdateAlgorithm(buildUpdateAlgorithm());
     }
 
-    public Policy getPolicy(IDQN nn) {
-        return new EpsGreedy(new DQNPolicy(nn), getMdp(), conf.getUpdateStart(), conf.getEpsilonNbStep(),
-                        rnd, conf.getMinEpsilon(), this);
+    public Policy getPolicy(IDQN nn) {
+        return new EpsGreedy(new DQNPolicy(nn), getMdp(), configuration.getUpdateStart(), configuration.getEpsilonNbStep(),
+                        rnd, configuration.getMinEpsilon(), this);
     }
 
     @Override
     protected UpdateAlgorithm buildUpdateAlgorithm() {
         int[] shape = getHistoryProcessor() == null ? getMdp().getObservationSpace().getShape() : getHistoryProcessor().getConf().getShape();
-        return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), conf.getGamma());
+        return new QLearningUpdateAlgorithm(shape, getMdp().getActionSpace().getSize(), configuration.getGamma());
     }
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
index d12db5d67..b2e06dc9c 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java
@@ -32,10 +32,10 @@ import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
 import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.EpsGreedy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.IDataManager.StatEntry;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
index b2ad597d0..771650340 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java
@@ -33,11 +33,11 @@ import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorith
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.DQNPolicy;
 import org.deeplearning4j.rl4j.policy.EpsGreedy;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java
index 98c690269..450d0e27e 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java
@@ -24,8 +24,8 @@ import org.deeplearning4j.rl4j.network.configuration.NetworkConfiguration;
 import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
 import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdConv;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
 import org.deeplearning4j.rl4j.util.IDataManager;
@@ -34,59 +34,59 @@ import org.deeplearning4j.rl4j.util.IDataManager;
  * Specialized constructors for the Conv (pixels input) case
  * Specialized conf + provide additional type safety
  */
-public class QLearningDiscreteConv extends QLearningDiscrete {
+public class QLearningDiscreteConv extends QLearningDiscrete {
 
     @Deprecated
-    public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
+    public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
                     QLConfiguration conf, IDataManager dataManager) {
         this(mdp, dqn, hpconf, conf);
         addListener(new DataManagerTrainingListener(dataManager));
     }
 
     @Deprecated
-    public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
+    public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
                     QLConfiguration conf) {
         super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep() * hpconf.getSkipFrame());
         setHistoryProcessor(hpconf);
     }
 
-    public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
+    public QLearningDiscreteConv(MDP mdp, IDQN dqn, HistoryProcessor.Configuration hpconf,
                     QLearningConfiguration conf) {
         super(mdp, dqn, conf, conf.getEpsilonNbStep() * hpconf.getSkipFrame());
         setHistoryProcessor(hpconf);
     }
 
     @Deprecated
-    public QLearningDiscreteConv(MDP mdp, DQNFactory factory,
+    public QLearningDiscreteConv(MDP mdp, DQNFactory factory,
                     HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf, dataManager);
     }
 
     @Deprecated
-    public QLearningDiscreteConv(MDP mdp, DQNFactory factory,
+    public QLearningDiscreteConv(MDP mdp, DQNFactory factory,
                     HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
     }
 
-    public QLearningDiscreteConv(MDP mdp, DQNFactory factory,
+    public QLearningDiscreteConv(MDP mdp, DQNFactory factory,
                     HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
         this(mdp, factory.buildDQN(hpconf.getShape(), mdp.getActionSpace().getSize()), hpconf, conf);
     }
 
     @Deprecated
-    public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf,
+    public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf,
                     HistoryProcessor.Configuration hpconf, QLConfiguration conf, IDataManager dataManager) {
         this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf, dataManager);
     }
 
     @Deprecated
-    public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf,
+    public QLearningDiscreteConv(MDP mdp, DQNFactoryStdConv.Configuration netConf,
                     HistoryProcessor.Configuration hpconf, QLConfiguration conf) {
         this(mdp, new DQNFactoryStdConv(netConf.toNetworkConfiguration()), hpconf, conf);
     }
 
-    public QLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf,
+    public QLearningDiscreteConv(MDP mdp, NetworkConfiguration netConf,
                     HistoryProcessor.Configuration hpconf, QLearningConfiguration conf) {
         this(mdp, new DQNFactoryStdConv(netConf), hpconf, conf);
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java
index 5b95cc84e..789e71b42 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java
@@ -24,65 +24,65 @@ import org.deeplearning4j.rl4j.network.configuration.DQNDenseNetworkConfiguratio
 import org.deeplearning4j.rl4j.network.dqn.DQNFactory;
 import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
 import org.deeplearning4j.rl4j.util.IDataManager;
 
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/6/16.
  */
-public class QLearningDiscreteDense extends QLearningDiscrete {
+public class QLearningDiscreteDense extends QLearningDiscrete {
 
     @Deprecated
-    public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf,
-                    IDataManager dataManager) {
+    public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf,
+                    IDataManager dataManager) {
         this(mdp, dqn, conf);
         addListener(new DataManagerTrainingListener(dataManager));
     }
 
     @Deprecated
-    public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf) {
+    public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearning.QLConfiguration conf) {
         super(mdp, dqn, conf.toLearningConfiguration(), conf.getEpsilonNbStep());
     }
 
-    public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearningConfiguration conf) {
+    public QLearningDiscreteDense(MDP mdp, IDQN dqn, QLearningConfiguration conf) {
         super(mdp, dqn, conf, conf.getEpsilonNbStep());
     }
 
     @Deprecated
-    public QLearningDiscreteDense(MDP mdp, DQNFactory factory,
-                    QLearning.QLConfiguration conf, IDataManager dataManager) {
+    public QLearningDiscreteDense(MDP mdp, DQNFactory factory,
+                    QLearning.QLConfiguration conf, IDataManager dataManager) {
         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf, dataManager);
     }
 
     @Deprecated
-    public QLearningDiscreteDense(MDP mdp, DQNFactory factory,
+    public QLearningDiscreteDense(MDP mdp, DQNFactory factory,
                     QLearning.QLConfiguration conf) {
         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
     }
 
-    public QLearningDiscreteDense(MDP mdp, DQNFactory factory,
+    public QLearningDiscreteDense(MDP mdp, DQNFactory factory,
                     QLearningConfiguration conf) {
         this(mdp, factory.buildDQN(mdp.getObservationSpace().getShape(), mdp.getActionSpace().getSize()), conf);
     }
 
     @Deprecated
-    public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf,
-                    QLearning.QLConfiguration conf, IDataManager dataManager) {
+    public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf,
+                    QLearning.QLConfiguration conf, IDataManager dataManager) {
         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf, dataManager);
     }
 
     @Deprecated
-    public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf,
+    public QLearningDiscreteDense(MDP mdp, DQNFactoryStdDense.Configuration netConf,
                     QLearning.QLConfiguration conf) {
         this(mdp, new DQNFactoryStdDense(netConf.toNetworkConfiguration()), conf);
     }
 
-    public QLearningDiscreteDense(MDP mdp, DQNDenseNetworkConfiguration netConf,
+    public QLearningDiscreteDense(MDP mdp, DQNDenseNetworkConfiguration netConf,
                     QLearningConfiguration conf) {
         this(mdp, new DQNFactoryStdDense(netConf), conf);
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java
new file mode 100644
index 000000000..1e1348b4a
--- /dev/null
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java
@@ -0,0 +1,129 @@
+package org.deeplearning4j.rl4j.mdp;
+
+import lombok.Getter;
+import lombok.Setter;
+import org.deeplearning4j.rl4j.environment.ActionSchema;
+import org.deeplearning4j.rl4j.environment.Environment;
+import org.deeplearning4j.rl4j.environment.Schema;
+import org.deeplearning4j.rl4j.environment.StepResult;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+public class CartpoleEnvironment implements Environment {
+    private static final int NUM_ACTIONS = 2;
+    private static final int ACTION_LEFT = 0;
+    private static final int ACTION_RIGHT = 1;
+
+    private static final Schema schema = new Schema<>(new ActionSchema<>(ACTION_LEFT));
+
+    public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
+
+    private static final double gravity = 9.8;
+    private static final double massCart = 1.0;
+    private static final double massPole = 0.1;
+    private static final double totalMass = massPole + massCart;
+    private static final double length = 0.5; // actually half the pole's length
+    private static final double polemassLength = massPole * length;
+    private static final double forceMag = 10.0;
+    private static final double tau = 0.02; // seconds between state updates
+
+    // Angle at which to fail the episode
+    private static final double thetaThresholdRadians = 12.0 * 2.0 * Math.PI / 360.0;
+    private static final double xThreshold = 2.4;
+
+    private final Random rnd;
+
+    @Getter @Setter
+    private KinematicsIntegrators kinematicsIntegrator = KinematicsIntegrators.Euler;
+
+    @Getter
+    private boolean episodeFinished = false;
+
+    private double x;
+    private double xDot;
+    private double theta;
+    private double thetaDot;
+    private Integer stepsBeyondDone;
+
+    public CartpoleEnvironment() {
+        rnd = new Random();
+    }
+
+    public CartpoleEnvironment(int seed) {
+        rnd = new Random(seed);
+    }
+
+    @Override
+    public Schema getSchema() {
+        return schema;
+    }
+
+    @Override
+    public Map reset() {
+
+        x = 0.1 * rnd.nextDouble() - 0.05;
+        xDot = 0.1 * rnd.nextDouble() - 0.05;
+        theta = 0.1 * rnd.nextDouble() - 0.05;
+        thetaDot = 0.1 * rnd.nextDouble() - 0.05;
+        stepsBeyondDone = null;
+        episodeFinished = false;
+
+        return new HashMap() {{
+            put("data", new double[]{x, xDot, theta, thetaDot});
+        }};
+    }
+
+    @Override
+    public StepResult step(Integer action) {
+        double force = action == ACTION_RIGHT ? forceMag : -forceMag;
+        double cosTheta = Math.cos(theta);
+        double sinTheta = Math.sin(theta);
+        double temp = (force + polemassLength * thetaDot * thetaDot * sinTheta) / totalMass;
+        double thetaAcc = (gravity * sinTheta - cosTheta* temp) / (length * (4.0/3.0 - massPole * cosTheta * cosTheta / totalMass));
+        double xAcc = temp - polemassLength * thetaAcc * cosTheta / totalMass;
+
+        switch(kinematicsIntegrator) {
+            case Euler:
+                x += tau * xDot;
+                xDot += tau * xAcc;
+                theta += tau * thetaDot;
+                thetaDot += tau * thetaAcc;
+                break;
+
+            case SemiImplicitEuler:
+                xDot += tau * xAcc;
+                x += tau * xDot;
+                thetaDot += tau * thetaAcc;
+                theta += tau * thetaDot;
+                break;
+        }
+
+        episodeFinished |= x < -xThreshold || x > xThreshold
+                || theta < -thetaThresholdRadians || theta > thetaThresholdRadians;
+
+        double reward;
+        if(!episodeFinished) {
+            reward = 1.0;
+        }
+        else if(stepsBeyondDone == null) {
+            stepsBeyondDone = 0;
+            reward = 1.0;
+        }
+        else {
+            ++stepsBeyondDone;
+            reward = 0;
+        }
+
+        Map channelsData = new HashMap() {{
+            put("data", new double[]{x, xDot, theta, thetaDot});
+        }};
+        return new StepResult(channelsData, reward, episodeFinished);
+    }
+
+    @Override
+    public void close() {
+        // Do nothing
+    }
+}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java
index 94aa79b0b..8b33e54d0 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java
@@ -4,8 +4,8 @@ import lombok.Getter;
 import lombok.Setter;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
+import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
 
 import java.util.Random;
@@ -36,7 +36,7 @@ import java.util.Random;
  */
-public class CartpoleNative implements MDP {
+public class CartpoleNative implements MDP {
 
     public enum KinematicsIntegrators { Euler, SemiImplicitEuler };
 
     private static final int NUM_ACTIONS = 2;
@@ -74,7 +74,7 @@ public class CartpoleNative implements MDP
-    private ObservationSpace observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
+    private ObservationSpace observationSpace = new ArrayObservationSpace(new int[] { OBSERVATION_NUM_FEATURES });
 
     public CartpoleNative() {
         rnd = new Random();
@@ -85,7 +85,7 @@ public class CartpoleNative implements MDP
-    public StepReply step(Integer action) {
+    public StepReply step(Integer action) {
         double force = action == ACTION_RIGHT ? forceMag : -forceMag;
         double cosTheta = Math.cos(theta);
         double sinTheta = Math.sin(theta);
@@ -143,26 +143,12 @@ public class CartpoleNative implements MDP
-        return new StepReply<>(new State(new double[] { x, xDot, theta, thetaDot }), reward, done, null);
+        return new StepReply<>(new Box(x, xDot, theta, thetaDot), reward, done, null);
     }
 
     @Override
-    public MDP newInstance() {
+    public MDP newInstance() {
         return new CartpoleNative();
     }
 
-    public static class State implements Encodable {
-
-        private final double[] state;
-
-        State(double[] state) {
-
-            this.state = state;
-        }
-
-        @Override
-        public double[] toArray() {
-            return state;
-        }
-    }
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java
index 6fd96b7ea..a357eaeda 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy;
 
 import lombok.Value;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) on 8/9/16.
@@ -31,4 +32,19 @@ public class HardToyState implements Encodable {
     public double[] toArray() {
         return values;
     }
+
+    @Override
+    public boolean isSkipped() {
+        return false;
+    }
+
+    @Override
+    public INDArray getData() {
+        return null;
+    }
+
+    @Override
+    public Encodable dup() {
+        return null;
+    }
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java
index 19b07b0b1..933332125 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java
@@ -24,6 +24,7 @@ import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
 import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
+import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -40,7 +41,6 @@ import org.nd4j.linalg.factory.Nd4j;
 public class SimpleToy implements MDP {
 
     final private int maxStep;
-    //TODO 10 steps toy (always +1 reward2 actions), toylong (1000 steps), toyhard (7 actions, +1 only if actiion = (step/100+step)%7, and toyStoch (like last but reward has 0.10 odd to be somewhere else).
     @Getter
     private DiscreteSpace actionSpace = new DiscreteSpace(2);
     @Getter
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java
index 1c38cf384..6e41ea414 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.rl4j.mdp.toy;
 
 import lombok.Value;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
  * @author rubenfiszel (ruben.fiszel@epfl.ch) 7/18/16.
@@ -28,11 +29,24 @@ public class SimpleToyState implements Encodable {
     int i;
     int step;
 
-    @Override
     public double[] toArray() {
         double[] ar = new double[1];
         ar[0] = (20 - i);
         return ar;
     }
+
+    @Override
+    public boolean isSkipped() {
+        return false;
+    }
+
+    @Override
+    public INDArray getData() {
+        return null;
+    }
+
+    @Override
+    public Encodable dup() {
+        return null;
+    }
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java
index 0444aa32d..6603429cc 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java
@@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  *
  * @author Alexandre Boulanger
  */
-public class Observation {
+public class Observation implements Encodable {
 
     /**
      * A singleton representing a skipped observation
@@ -38,6 +38,11 @@ public class Observation {
     @Getter
     private final INDArray data;
 
+    @Override
+    public double[] toArray() {
+        return data.data().asDouble();
+    }
+
     public boolean isSkipped() {
         return data == null;
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java
similarity index 64%
rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java
rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java
index a9214bbff..8be8c7ed9 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToINDArrayTransform.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java
@@ -1,41 +1,28 @@
-/*******************************************************************************
- * Copyright (c) 2020 Konduit K.K.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-package org.deeplearning4j.rl4j.observation.transform.legacy;
-
-import org.bytedeco.javacv.OpenCVFrameConverter;
-import org.bytedeco.opencv.opencv_core.Mat;
-import org.datavec.api.transform.Operation;
-import org.datavec.image.data.ImageWritable;
-import org.deeplearning4j.rl4j.space.Encodable;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
-
-public class EncodableToINDArrayTransform implements Operation {
-
-    private final int[] shape;
-
-    public EncodableToINDArrayTransform(int[] shape) {
-        this.shape = shape;
-    }
-
-    @Override
-    public INDArray transform(Encodable encodable) {
-        return Nd4j.create(encodable.toArray()).reshape(shape);
-    }
-
-}
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K. K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.observation.transform;
+
+import org.datavec.api.transform.Operation;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+public class EncodableToINDArrayTransform implements Operation {
+    @Override
+    public INDArray transform(Encodable encodable) {
+        return encodable.getData();
+    }
+}
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java
index 133fbdb61..870b366ff 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java
@@ -15,34 +15,32 @@
  ******************************************************************************/
 package org.deeplearning4j.rl4j.observation.transform.legacy;
 
+import org.bytedeco.javacv.Frame;
 import org.bytedeco.javacv.OpenCVFrameConverter;
 import org.bytedeco.opencv.opencv_core.Mat;
 import org.datavec.api.transform.Operation;
 import org.datavec.image.data.ImageWritable;
+import org.datavec.image.loader.NativeImageLoader;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
 import static org.bytedeco.opencv.global.opencv_core.CV_32FC;
+import static org.bytedeco.opencv.global.opencv_core.CV_32FC3;
+import static org.bytedeco.opencv.global.opencv_core.CV_32S;
+import static org.bytedeco.opencv.global.opencv_core.CV_32SC;
+import static org.bytedeco.opencv.global.opencv_core.CV_32SC3;
+import static org.bytedeco.opencv.global.opencv_core.CV_64FC;
+import static org.bytedeco.opencv.global.opencv_core.CV_8UC3;
 
 public class EncodableToImageWritableTransform implements Operation {
 
-    private final OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
-    private final int height;
-    private final int width;
-    private final int colorChannels;
-
-    public EncodableToImageWritableTransform(int height, int width, int colorChannels) {
-        this.height = height;
-        this.width = width;
-        this.colorChannels = colorChannels;
-    }
+    final static NativeImageLoader nativeImageLoader = new NativeImageLoader();
 
     @Override
     public ImageWritable transform(Encodable encodable) {
-        INDArray indArray = Nd4j.create(encodable.toArray()).reshape(height, width, colorChannels);
-        Mat mat = new Mat(height, width, CV_32FC(3), indArray.data().pointer());
-        return new ImageWritable(converter.convert(mat));
+        return new ImageWritable(nativeImageLoader.asFrame(encodable.getData(), Frame.DEPTH_UBYTE));
     }
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java
index 3a48c128a..88615325d 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java
@@ -18,34 +18,31 @@ package org.deeplearning4j.rl4j.observation.transform.legacy;
 import org.datavec.api.transform.Operation;
 import org.datavec.image.data.ImageWritable;
 import org.datavec.image.loader.NativeImageLoader;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
 
 import java.io.IOException;
 
 public class ImageWritableToINDArrayTransform implements Operation {
 
-    private final int height;
-    private final int width;
-    private final NativeImageLoader loader;
-
-    public ImageWritableToINDArrayTransform(int height, int width) {
-        this.height = height;
-        this.width = width;
-        this.loader = new NativeImageLoader(height, width);
-    }
+    private final NativeImageLoader loader = new NativeImageLoader();
 
     @Override
     public INDArray transform(ImageWritable imageWritable) {
+
+        int height = imageWritable.getHeight();
+        int width = imageWritable.getWidth();
+        int channels = imageWritable.getFrame().imageChannels;
+
         INDArray out = null;
         try {
             out = loader.asMatrix(imageWritable);
         } catch (IOException e) {
             e.printStackTrace();
         }
-        out = out.reshape(1, height, width);
+
+        // Convert back to uint8 and reshape to the number of channels in the image
+        out = out.reshape(channels, height, width);
         INDArray compressed = out.castTo(DataType.UINT8);
         return compressed;
     }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java
index e27d1134c..e8920bbdd 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java
@@ -46,19 +46,20 @@ public class HistoryMergeTransform implements Operation, Res
     private final HistoryMergeElementStore historyMergeElementStore;
     private final HistoryMergeAssembler historyMergeAssembler;
     private final boolean shouldStoreCopy;
-    private final boolean isFirstDimenstionBatch;
+    private final boolean isFirstDimensionBatch;
 
     private HistoryMergeTransform(Builder builder) {
         this.historyMergeElementStore = builder.historyMergeElementStore;
         this.historyMergeAssembler = builder.historyMergeAssembler;
         this.shouldStoreCopy = builder.shouldStoreCopy;
-        this.isFirstDimenstionBatch = builder.isFirstDimenstionBatch;
+        this.isFirstDimensionBatch = builder.isFirstDimenstionBatch;
     }
 
     @Override
     public INDArray transform(INDArray input) {
+
         INDArray element;
-        if(isFirstDimenstionBatch) {
+        if(isFirstDimensionBatch) {
             element = input.slice(0, 0);
         }
         else {
@@ -132,9 +133,9 @@ public class HistoryMergeTransform implements Operation, Res
             return this;
         }
 
-        public HistoryMergeTransform build() {
+        public HistoryMergeTransform build(int frameStackLength) {
             if(historyMergeElementStore == null) {
-                historyMergeElementStore = new CircularFifoStore();
+                historyMergeElementStore = new CircularFifoStore(frameStackLength);
             }
 
             if(historyMergeAssembler == null) {
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java
index db1cbb2bd..5b00bba3c 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java
@@ -28,14 +28,9 @@ import org.nd4j.linalg.factory.Nd4j;
  * @author Alexandre Boulanger
  */
 public class CircularFifoStore implements HistoryMergeElementStore {
-    private static final int DEFAULT_STORE_SIZE = 4;
 
     private final CircularFifoQueue queue;
 
-    public CircularFifoStore() {
-        this(DEFAULT_STORE_SIZE);
-    }
-
     public CircularFifoStore(int size) {
         Preconditions.checkArgument(size > 0, "The size must be at least 1, got %s", size);
         queue = new CircularFifoQueue<>(size);
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java
index 61ba70825..6824e75cb 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java
@@ -20,8 +20,8 @@ import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
 import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
-import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;
@@ -35,7 +35,7 @@ import java.io.IOException;
  * the softmax output of the actor critic, but objects constructed
  * with a {@link Random} argument of null return the max only.
  */
-public class ACPolicy extends Policy {
+public class ACPolicy extends Policy {
 
     final private IActorCritic actorCritic;
     Random rnd;
@@ -48,18 +48,18 @@ public class ACPolicy extends Policy {
         this.rnd = rnd;
     }
 
-    public static ACPolicy load(String path) throws IOException {
-        return new ACPolicy(ActorCriticCompGraph.load(path));
+    public static ACPolicy load(String path) throws IOException {
+        return new ACPolicy<>(ActorCriticCompGraph.load(path));
     }
-    public static ACPolicy load(String path, Random rnd) throws IOException {
-        return new ACPolicy(ActorCriticCompGraph.load(path), rnd);
+    public static ACPolicy load(String path, Random rnd) throws IOException {
+        return new ACPolicy<>(ActorCriticCompGraph.load(path), rnd);
     }
 
-    public static ACPolicy load(String pathValue, String pathPolicy) throws IOException {
-        return new ACPolicy(ActorCriticSeparate.load(pathValue, pathPolicy));
+    public static ACPolicy load(String pathValue, String pathPolicy) throws IOException {
+        return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy));
     }
-    public static ACPolicy load(String pathValue, String pathPolicy, Random rnd) throws IOException {
-        return new ACPolicy(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
+    public static ACPolicy load(String pathValue, String pathPolicy, Random rnd) throws IOException {
+        return new ACPolicy<>(ActorCriticSeparate.load(pathValue, pathPolicy), rnd);
     }
 
     public IActorCritic getNeuralNet() {
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java
index cf2b60f41..6f2e63620 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java
@@ -17,8 +17,8 @@ package org.deeplearning4j.rl4j.policy;
 
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
 
@@ -30,7 +30,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.exp;
  * Boltzmann exploration is a stochastic policy wrt to the
 * exponential Q-values as evaluated by the dqn model.
  */
-public class BoltzmannQ extends Policy {
+public class BoltzmannQ extends Policy {
 
     final private IDQN dqn;
     final private Random rnd;
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java
index c7ef91665..ed591a1ff 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java
@@ -20,8 +20,8 @@ import lombok.AllArgsConstructor;
 import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.network.dqn.DQN;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
-import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.observation.Observation;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.io.IOException;
@@ -32,13 +32,15 @@ import java.io.IOException;
  * DQN policy returns the action with the maximum Q-value as evaluated
  * by the dqn model
  */
+
+// FIXME: Should we rename this "GreedyPolicy"?
 @AllArgsConstructor
-public class DQNPolicy extends Policy {
+public class DQNPolicy extends Policy {
 
     final private IDQN dqn;
 
-    public static DQNPolicy load(String path) throws IOException {
-        return new DQNPolicy(DQN.load(path));
+    public static DQNPolicy load(String path) throws IOException {
+        return new DQNPolicy<>(DQN.load(path));
     }
 
     public IDQN getNeuralNet() {
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java
index 2c7695dc7..a7282f139 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java
@@ -20,12 +20,11 @@ package org.deeplearning4j.rl4j.policy;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.rl4j.learning.IEpochTrainer;
-import org.deeplearning4j.rl4j.learning.ILearning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.rng.Random;
 
@@ -41,10 +40,10 @@ import org.nd4j.linalg.api.rng.Random;
  */
 @AllArgsConstructor
 @Slf4j
-public class EpsGreedy> extends Policy {
+public class EpsGreedy> extends Policy {
 
-    final private Policy policy;
-    final private MDP mdp;
+    final private Policy policy;
+    final private MDP mdp;
     final private int updateStart;
     final private int epsilonNbStep;
     final private Random rnd;
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java
index f87971a89..ffc029835 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java
@@ -7,8 +7,14 @@ import org.deeplearning4j.rl4j.space.ActionSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
-public interface IPolicy {
-    > double play(MDP mdp, IHistoryProcessor hp);
-    A nextAction(INDArray input);
-    A nextAction(Observation observation);
+public interface IPolicy {
+    @Deprecated
+    > double play(MDP mdp, IHistoryProcessor hp);
+
+    @Deprecated
+    ACTION nextAction(INDArray input);
+
+    ACTION nextAction(Observation observation);
+
+    void reset();
 }
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java
index d5fa59766..6a4146c94 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java
@@ -22,9 +22,9 @@ import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.Learning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.NeuralNet;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 
 /**
@@ -34,22 +34,22 @@ import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
  *
 * A Policy responsability is to choose the next action given a state
  */
-public abstract class Policy implements IPolicy {
+public abstract class Policy implements IPolicy {
 
     public abstract NeuralNet getNeuralNet();
 
     public abstract A nextAction(Observation obs);
 
-    public > double play(MDP mdp) {
+    public > double play(MDP mdp) {
         return play(mdp, (IHistoryProcessor)null);
     }
 
-    public > double play(MDP mdp, HistoryProcessor.Configuration conf) {
+    public > double play(MDP mdp, HistoryProcessor.Configuration conf) {
         return play(mdp, new HistoryProcessor(conf));
     }
 
     @Override
-    public > double play(MDP mdp, IHistoryProcessor hp) {
+    public > double play(MDP mdp, IHistoryProcessor hp) {
         resetNetworks();
 
         LegacyMDPWrapper mdpWrapper = new LegacyMDPWrapper(mdp, hp);
@@ -84,8 +84,11 @@ public abstract class Policy implements IPolicy {
     protected void resetNetworks() {
         getNeuralNet().reset();
     }
+
+    public void reset() {
+        resetNetworks();
+    }
 
-    protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) {
+    protected > Learning.InitMdp refacInitMdp(LegacyMDPWrapper mdpWrapper, IHistoryProcessor hp) {
 
         double reward = 0;
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
index 981f35379..cc0a12e38 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java
@@ -7,22 +7,22 @@ import org.datavec.image.transform.ColorConversionTransform;
 import org.datavec.image.transform.CropImageTransform;
 import org.datavec.image.transform.MultiImageTransform;
 import org.datavec.image.transform.ResizeImageTransform;
+import org.datavec.image.transform.ShowImageTransform;
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
 import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
-import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
 import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
 import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
 import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
 
 import java.util.HashMap;
 import java.util.Map;
@@ -46,6 +46,7 @@ public class LegacyMDPWrapper
     public LegacyMDPWrapper(MDP wrappedMDP, IHistoryProcessor historyProcessor) {
         this.wrappedMDP = wrappedMDP;
         this.shape = wrappedMDP.getObservationSpace().getShape();
@@ -66,28 +67,33 @@ public class LegacyMDPWrapper
         Map channelsData = buildChannelsData(rawStepReply.getObservation());
         Observation observation = transformProcess.transform(channelsData, stepOfObservation, rawStepReply.isDone());
+
         return new StepReply(observation, rawStepReply.getReward(), rawStepReply.isDone(), rawStepReply.getInfo());
     }
@@ -161,12 +168,7 @@ public class LegacyMDPWrapper {
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java
index 5a91b71e4..e7e9fcd4c 100644
--- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java
+++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java
@@ -16,26 +16,21 @@
 package org.deeplearning4j.rl4j.util;
 
-import lombok.Getter;
 import lombok.extern.slf4j.Slf4j;
-import org.bytedeco.javacpp.BytePointer;
-import org.bytedeco.javacpp.Pointer;
 import org.bytedeco.javacv.FFmpegFrameRecorder;
 import org.bytedeco.javacv.Frame;
-import org.bytedeco.javacv.OpenCVFrameConverter;
-import org.bytedeco.opencv.global.opencv_core;
-import org.bytedeco.opencv.global.opencv_imgproc;
-import org.bytedeco.opencv.opencv_core.Mat;
-import org.bytedeco.opencv.opencv_core.Rect;
-import org.bytedeco.opencv.opencv_core.Size;
-import org.opencv.imgproc.Imgproc;
+import org.datavec.image.loader.NativeImageLoader;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
-import static org.bytedeco.ffmpeg.global.avcodec.*;
-import static org.bytedeco.opencv.global.opencv_core.*;
+import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_H264;
+import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_MPEG4;
+import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB0;
+import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB24;
+import static org.bytedeco.ffmpeg.global.avutil.AV_PIX_FMT_RGB8;
 
 /**
- * VideoRecorder is used to create a video from a sequence of individual frames. If using 3 channels
- * images, it expects B-G-R order. A RGB order can be used by calling isRGBOrder(true).
+ * VideoRecorder is used to create a video from a sequence of INDArray frames. INDArrays are assumed to be in CHW format, where C=3 and pixels are RGB encoded.
  * Example:
  *

  * {@code
@@ -45,11 +40,8 @@ import static org.bytedeco.opencv.global.opencv_core.*;
  *             .build();
  *         recorder.startRecording("myVideo.mp4");
  *         while(...) {
- *             byte[] data = new byte[160*100*3];
- *             // Todo: Fill data
- *             VideoRecorder.VideoFrame frame = recorder.createFrame(data);
- *             // Todo: Apply cropping or resizing to frame
- *             recorder.record(frame);
+ *             INDArray chwData = Nd4j.create(3, height, width); // a CHW frame (C=3, RGB) to fill with pixel data
+ *             recorder.record(chwData);
  *         }
  *         recorder.stopRecording();
  * }
@@ -60,16 +52,13 @@ import static org.bytedeco.opencv.global.opencv_core.*;
 @Slf4j
 public class VideoRecorder implements AutoCloseable {
 
-    public enum FrameInputTypes { BGR, RGB, Float }
+    private final NativeImageLoader nativeImageLoader = new NativeImageLoader();
 
     private final int height;
     private final int width;
-    private final int imageType;
-    private final OpenCVFrameConverter openCVFrameConverter = new OpenCVFrameConverter.ToMat();
     private final int codec;
     private final double framerate;
     private final int videoQuality;
-    private final FrameInputTypes frameInputType;
 
     private FFmpegFrameRecorder fmpegFrameRecorder = null;
 
@@ -83,11 +72,9 @@ public class VideoRecorder implements AutoCloseable {
     private VideoRecorder(Builder builder) {
         this.height = builder.height;
         this.width = builder.width;
-        imageType = CV_8UC(builder.numChannels);
         codec = builder.codec;
         framerate = builder.frameRate;
         videoQuality = builder.videoQuality;
-        frameInputType = builder.frameInputType;
     }
 
     /**
@@ -119,59 +106,11 @@ public class VideoRecorder implements AutoCloseable {
 
     /**
      * Add a frame to the video
-     * @param frame the VideoFrame to add to the video
+     * @param imageArray the INDArray containing the data to record; the data must be in CHW format
      * @throws Exception
      */
-    public void record(VideoFrame frame) throws Exception {
-        Size size = frame.getMat().size();
-        if(size.height() != height || size.width() != width) {
-            throw new IllegalArgumentException(String.format("Wrong frame size. Got (%dh x %dw) expected (%dh x %dw)", size.height(), size.width(), height, width));
-        }
-        Frame cvFrame = openCVFrameConverter.convert(frame.getMat());
-        fmpegFrameRecorder.record(cvFrame);
-    }
-
-    /**
-     * Create a VideoFrame from a byte array.
-     * @param data A byte array. Expect the index to be of the form [(Y*Width + X) * NumChannels + channel]
-     * @return An instance of VideoFrame
-     */
-    public VideoFrame createFrame(byte[] data) {
-        return createFrame(new BytePointer(data));
-    }
-
-    /**
-     * Create a VideoFrame from a byte array with different height and width than the video
-     * the frame will need to be cropped or resized before being added to the video)
-     *
-     * @param data A byte array Expect the index to be of the form [(Y*customWidth + X) * NumChannels + channel]
-     * @param customHeight The actual height of the data
-     * @param customWidth The actual width of the data
-     * @return A VideoFrame instance
-     */
-    public VideoFrame createFrame(byte[] data, int customHeight, int customWidth) {
-        return createFrame(new BytePointer(data), customHeight, customWidth);
-    }
-
-    /**
-     * Create a VideoFrame from a Pointer (to use for example with a INDarray).
-     * @param data A Pointer (for example myINDArray.data().pointer())
-     * @return An instance of VideoFrame
-     */
-    public VideoFrame createFrame(Pointer data) {
-        return new VideoFrame(height, width, imageType, frameInputType, data);
-    }
-
-    /**
-     *  Create a VideoFrame from a Pointer with different height and width than the video
-     * the frame will need to be cropped or resized before being added to the video)
-     * @param data
-     * @param customHeight The actual height of the data
-     * @param customWidth The actual width of the data
-     * @return A VideoFrame instance
-     */
-    public VideoFrame createFrame(Pointer data, int customHeight, int customWidth) {
-        return new VideoFrame(customHeight, customWidth, imageType, frameInputType, data);
+    public void record(INDArray imageArray) throws Exception {
+        fmpegFrameRecorder.record(nativeImageLoader.asFrame(imageArray, Frame.DEPTH_UBYTE));
     }
 
     /**
@@ -192,69 +131,12 @@ public class VideoRecorder implements AutoCloseable {
         return new Builder(height, width);
     }
 
-    /**
-     * An individual frame for the video
-     */
-    public static class VideoFrame {
-
-        private final int height;
-        private final int width;
-        private final int imageType;
-        @Getter
-        private Mat mat;
-
-        private VideoFrame(int height, int width, int imageType, FrameInputTypes frameInputType, Pointer data) {
-            this.height = height;
-            this.width = width;
-            this.imageType = imageType;
-
-            switch(frameInputType) {
-                case RGB:
-                    Mat src = new Mat(height, width, imageType, data);
-                    mat = new Mat(height, width, imageType);
-                    opencv_imgproc.cvtColor(src, mat, Imgproc.COLOR_RGB2BGR);
-                    break;
-
-                case BGR:
-                    mat = new Mat(height, width, imageType, data);
-                    break;
-
-                case Float:
-                    Mat tmpMat = new Mat(height, width, CV_32FC(3), data);
-                    mat = new Mat(height, width, imageType);
-                    tmpMat.convertTo(mat, CV_8UC(3), 255.0, 0.0);
-            }
-        }
-
-        /**
-         * Crop the video to a specified size
-         * @param newHeight The new height of the frame
-         * @param newWidth The new width of the frame
-         * @param heightOffset The starting height offset in the uncropped frame
-         * @param widthOffset The starting weight offset in the uncropped frame
-         */
-        public void crop(int newHeight, int newWidth, int heightOffset, int widthOffset) {
-            mat = mat.apply(new Rect(widthOffset, heightOffset, newWidth, newHeight));
-        }
-
-        /**
-         * Resize the frame to a specified size
-         * @param newHeight The new height of the frame
-         * @param newWidth The new width of the frame
-         */
-        public void resize(int newHeight, int newWidth) {
-            mat = new Mat(newHeight, newWidth, imageType);
-        }
-    }
-
     /**
      * A builder class for the VideoRecorder
      */
     public static class Builder {
         private final int height;
         private final int width;
-        private int numChannels = 3;
-        private FrameInputTypes frameInputType = FrameInputTypes.BGR;
         private int codec = AV_CODEC_ID_H264;
         private double frameRate = 30.0;
         private int videoQuality = 30;
@@ -268,24 +150,6 @@ public class VideoRecorder implements AutoCloseable {
             this.width = width;
         }
 
-        /**
-         * Specify the number of channels. Default is 3
-         * @param numChannels
-         */
-        public Builder numChannels(int numChannels) {
-            this.numChannels = numChannels;
-            return this;
-        }
-
-        /**
-         * Tell the VideoRecorder what data it will receive (default is BGR)
-         * @param frameInputType (See {@link FrameInputTypes}}
-         */
-        public Builder frameInputType(FrameInputTypes frameInputType) {
-            this.frameInputType = frameInputType;
-            return this;
-        }
-
         /**
          * The codec to use for the video. Default is AV_CODEC_ID_H264
          * @param codec Code (see {@link org.bytedeco.ffmpeg.global.avcodec codec codes})
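
A quick usage sketch of the INDArray-based recorder changed above (not part of the patch; the exact builder factory signature, the 100x160 shape, and the zero-filled frames are assumptions for illustration):

    import org.deeplearning4j.rl4j.util.VideoRecorder;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class VideoRecorderExample {
        public static void main(String[] args) throws Exception {
            int height = 100, width = 160;
            // Defaults per this diff: H264 codec, 30 fps, video quality 30
            VideoRecorder recorder = VideoRecorder.builder(height, width).build();
            recorder.startRecording("myVideo.mp4");
            for (int i = 0; i < 90; i++) {
                // One CHW frame (C=3, RGB); zeros stand in for real pixel data
                INDArray chwData = Nd4j.zeros(DataType.UINT8, 3, height, width);
                recorder.record(chwData);
            }
            recorder.stopRecording();
        }
    }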
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java
new file mode 100644
index 000000000..a8beae640
--- /dev/null
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java
@@ -0,0 +1,483 @@
+package org.deeplearning4j.rl4j.agent;
+
+import org.deeplearning4j.rl4j.agent.listener.AgentListener;
+import org.deeplearning4j.rl4j.environment.ActionSchema;
+import org.deeplearning4j.rl4j.environment.Environment;
+import org.deeplearning4j.rl4j.environment.Schema;
+import org.deeplearning4j.rl4j.environment.StepResult;
+import org.deeplearning4j.rl4j.observation.Observation;
+import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
+import org.deeplearning4j.rl4j.policy.IPolicy;
+import org.junit.Rule;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.mockito.*;
+import org.mockito.junit.*;
+import org.nd4j.linalg.factory.Nd4j;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.mockito.ArgumentMatchers.*;
+import static org.mockito.Mockito.*;
+
+public class AgentTest {
+
+    @Mock Environment environmentMock;
+    @Mock TransformProcess transformProcessMock;
+    @Mock IPolicy policyMock;
+    @Mock AgentListener listenerMock;
+
+    @Rule
+    public MockitoRule mockitoRule = MockitoJUnit.rule();
+
+    @Test
+    public void when_buildingWithNullEnvironment_expect_exception() {
+        try {
+            Agent.builder(null, null, null).build();
+            fail("NullPointerException should have been thrown");
+        } catch (NullPointerException exception) {
+            String expectedMessage = "environment is marked non-null but is null";
+            String actualMessage = exception.getMessage();
+
+            assertTrue(actualMessage.contains(expectedMessage));
+        }
+    }
+
+    @Test
+    public void when_buildingWithNullTransformProcess_expect_exception() {
+        try {
+            Agent.builder(environmentMock, null, null).build();
+            fail("NullPointerException should have been thrown");
+        } catch (NullPointerException exception) {
+            String expectedMessage = "transformProcess is marked non-null but is null";
+            String actualMessage = exception.getMessage();
+
+            assertTrue(actualMessage.contains(expectedMessage));
+        }
+    }
+
+    @Test
+    public void when_buildingWithNullPolicy_expect_exception() {
+        try {
+            Agent.builder(environmentMock, transformProcessMock, null).build();
+            fail("NullPointerException should have been thrown");
+        } catch (NullPointerException exception) {
+            String expectedMessage = "policy is marked non-null but is null";
+            String actualMessage = exception.getMessage();
+
+            assertTrue(actualMessage.contains(expectedMessage));
+        }
+    }
+
+    @Test
+    public void when_buildingWithInvalidMaxSteps_expect_exception() {
+        try {
+            Agent.builder(environmentMock, transformProcessMock, policyMock)
+                    .maxEpisodeSteps(0)
+                    .build();
+            fail("IllegalArgumentException should have been thrown");
+        } catch (IllegalArgumentException exception) {
+            String expectedMessage = "maxEpisodeSteps must be greater than 0, got [0]";
+            String actualMessage = exception.getMessage();
+
+            assertTrue(actualMessage.contains(expectedMessage));
+        }
+    }
+
+    @Test
+    public void when_buildingWithId_expect_idSetInAgent() {
+        // Arrange
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .id("TestAgent")
+                .build();
+
+        // Assert
+        assertEquals("TestAgent", sut.getId());
+    }
+
+    @Test
+    public void when_runIsCalled_expect_agentIsReset() {
+        // Arrange
+        Map<String, Object> envResetResult = new HashMap<>();
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(envResetResult);
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(1);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .build();
+
+        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), anyInt())).thenReturn(AgentListener.ListenerResponse.STOP);
+        sut.addListener(listenerMock);
+
+        // Act
+        sut.run();
+
+        // Assert
+        assertEquals(0, sut.getEpisodeStepNumber());
+        verify(transformProcessMock).transform(envResetResult, 0, false);
+        verify(policyMock, times(1)).reset();
+        assertEquals(0.0, sut.getReward(), 0.00001);
+        verify(environmentMock, times(1)).reset();
+    }
+
+    @Test
+    public void when_runIsCalled_expect_onBeforeAndAfterEpisodeCalled() {
+        // Arrange
+        Map<String, Object> envResetResult = new HashMap<>();
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(envResetResult);
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+        when(environmentMock.isEpisodeFinished()).thenReturn(true);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy, times(1)).onBeforeEpisode();
+        verify(spy, times(1)).onAfterEpisode();
+    }
+
+    @Test
+    public void when_onBeforeEpisodeReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
+        // Arrange
+        Map<String, Object> envResetResult = new HashMap<>();
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(envResetResult);
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
+
+        when(listenerMock.onBeforeEpisode(any(Agent.class))).thenReturn(AgentListener.ListenerResponse.STOP);
+        sut.addListener(listenerMock);
+
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy, times(1)).onBeforeEpisode();
+        verify(spy, never()).performStep();
+        verify(spy, never()).onAfterStep(any(StepResult.class));
+        verify(spy, never()).onAfterEpisode();
+    }
+
+    @Test
+    public void when_runIsCalledWithoutMaxStep_expect_agentRunUntilEpisodeIsFinished() {
+        // Arrange
+        Map<String, Object> envResetResult = new HashMap<>();
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(envResetResult);
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .build();
+
+        final Agent spy = Mockito.spy(sut);
+
+        doAnswer(invocation -> {
+            ((Agent)invocation.getMock()).incrementEpisodeStepNumber();
+            return null;
+        }).when(spy).performStep();
+        when(environmentMock.isEpisodeFinished()).thenAnswer(invocation -> spy.getEpisodeStepNumber() >= 5 );
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy, times(1)).onBeforeEpisode();
+        verify(spy, times(5)).performStep();
+        verify(spy, times(1)).onAfterEpisode();
+    }
+
+    @Test
+    public void when_maxStepsIsReachedBeforeEpisodeEnds_expect_runTerminated() {
+        // Arrange
+        Map<String, Object> envResetResult = new HashMap<>();
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(envResetResult);
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(3)
+                .build();
+
+        final Agent spy = Mockito.spy(sut);
+
+        doAnswer(invocation -> {
+            ((Agent)invocation.getMock()).incrementEpisodeStepNumber();
+            return null;
+        }).when(spy).performStep();
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy, times(1)).onBeforeEpisode();
+        verify(spy, times(3)).performStep();
+        verify(spy, times(1)).onAfterEpisode();
+    }
+
+    @Test
+    public void when_initialObservationsAreSkipped_expect_performNoOpAction() {
+        // Arrange
+        Map<String, Object> envResetResult = new HashMap<>();
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(envResetResult);
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(Observation.SkippedObservation);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .build();
+
+        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
+        sut.addListener(listenerMock);
+
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(listenerMock).onBeforeStep(any(), any(), eq(-1));
+    }
+
+    @Test
+    public void when_observationsIsSkipped_expect_performLastAction() {
+        // Arrange
+        Map<String, Object> envResetResult = new HashMap<>();
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(envResetResult);
+        when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(envResetResult, 0.0, false));
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(policyMock.nextAction(any(Observation.class)))
+                .thenAnswer(invocation -> (int)((Observation)invocation.getArgument(0)).getData().getDouble(0));
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(3)
+                .build();
+
+        Agent spy = Mockito.spy(sut);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean()))
+                .thenAnswer(invocation -> {
+                    int stepNumber = (int)invocation.getArgument(1);
+                    return stepNumber  % 2 == 1 ? Observation.SkippedObservation
+                            : new Observation(Nd4j.create(new double[] {  stepNumber }));
+                });
+
+        sut.addListener(listenerMock);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(policyMock, times(2)).nextAction(any(Observation.class));
+
+        ArgumentCaptor<Agent> agentCaptor = ArgumentCaptor.forClass(Agent.class);
+        ArgumentCaptor<Observation> observationCaptor = ArgumentCaptor.forClass(Observation.class);
+        ArgumentCaptor<Integer> actionCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(listenerMock, times(3)).onBeforeStep(agentCaptor.capture(), observationCaptor.capture(), actionCaptor.capture());
+        List<Integer> capturedActions = actionCaptor.getAllValues();
+        assertEquals(0, (int)capturedActions.get(0));
+        assertEquals(0, (int)capturedActions.get(1));
+        assertEquals(2, (int)capturedActions.get(2));
+    }
+
+    @Test
+    public void when_onBeforeStepReturnsStop_expect_performStepAndOnAfterEpisodeNotCalled() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock).build();
+
+        when(listenerMock.onBeforeStep(any(Agent.class), any(Observation.class), any())).thenReturn(AgentListener.ListenerResponse.STOP);
+        sut.addListener(listenerMock);
+
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy, times(1)).onBeforeEpisode();
+        verify(spy, times(1)).onBeforeStep();
+        verify(spy, never()).act(any());
+        verify(spy, never()).onAfterStep(any(StepResult.class));
+        verify(spy, never()).onAfterEpisode();
+    }
+
+    @Test
+    public void when_observationIsNotSkipped_expect_policyActionIsSentToEnvironment() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 0.0, false));
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(1)
+                .build();
+
+        // Act
+        sut.run();
+
+        // Assert
+        verify(environmentMock, times(1)).step(123);
+    }
+
+    @Test
+    public void when_stepResultIsReceived_expect_observationAndRewardUpdated() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        when(environmentMock.step(any(Integer.class))).thenReturn(new StepResult(new HashMap<>(), 234.0, false));
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(1)
+                .build();
+
+        // Act
+        sut.run();
+
+        // Assert
+        assertEquals(123.0, sut.getObservation().getData().getDouble(0), 0.00001);
+        assertEquals(234.0, sut.getReward(), 0.00001);
+    }
+
+    @Test
+    public void when_stepIsDone_expect_onAfterStepAndWithStepResult() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
+        when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(1)
+                .build();
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy).onAfterStep(stepResult);
+    }
+
+    @Test
+    public void when_onAfterStepReturnsStop_expect_onAfterEpisodeNotCalled() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
+        when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(1)
+                .build();
+        when(listenerMock.onAfterStep(any(Agent.class), any(StepResult.class))).thenReturn(AgentListener.ListenerResponse.STOP);
+        sut.addListener(listenerMock);
+
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy, never()).onAfterEpisode();
+    }
+
+    @Test
+    public void when_runIsCalled_expect_onAfterEpisodeIsCalled() {
+        // Arrange
+        Schema schema = new Schema(new ActionSchema<>(-1));
+        when(environmentMock.reset()).thenReturn(new HashMap<>());
+        when(environmentMock.getSchema()).thenReturn(schema);
+        StepResult stepResult = new StepResult(new HashMap<>(), 234.0, false);
+        when(environmentMock.step(any(Integer.class))).thenReturn(stepResult);
+
+        when(transformProcessMock.transform(any(Map.class), anyInt(), anyBoolean())).thenReturn(new Observation(Nd4j.create(new double[] { 123.0 })));
+
+        when(policyMock.nextAction(any(Observation.class))).thenReturn(123);
+
+        Agent sut = Agent.builder(environmentMock, transformProcessMock, policyMock)
+                .maxEpisodeSteps(1)
+                .build();
+
+        Agent spy = Mockito.spy(sut);
+
+        // Act
+        spy.run();
+
+        // Assert
+        verify(spy, times(1)).onAfterEpisode();
+    }
+}
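
Taken together, these tests pin down the Agent life-cycle: build with a non-null environment, transform process and policy, optionally cap the episode length, then run() resets everything and loops until the episode finishes, the cap is hit, or a listener returns STOP. A minimal sketch of that call pattern, assuming real implementations in place of the mocks (names as in the tests):

    Agent agent = Agent.builder(environment, transformProcess, policy)
            .id("my-agent")        // optional; surfaced via getId()
            .maxEpisodeSteps(200)  // must be > 0, else IllegalArgumentException
            .build();
    agent.addListener(listener);   // any ListenerResponse.STOP ends the run early
    agent.run();                   // reset -> step loop -> onAfterEpisode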
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
index 5f2a8ab31..de9778a80 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java
@@ -62,7 +62,7 @@ public class AsyncThreadDiscreteTest {
     IAsyncGlobal<NeuralNet> mockAsyncGlobal;
 
     @Mock
-    Policy<MockEncodable, Integer> mockGlobalCurrentPolicy;
+    Policy<MockObservation, Integer> mockGlobalCurrentPolicy;
 
     @Mock
     NeuralNet mockGlobalTargetNetwork;
@@ -115,7 +115,7 @@ public class AsyncThreadDiscreteTest {
 
         asyncThreadDiscrete.setUpdateAlgorithm(mockUpdateAlgorithm);
 
-        when(asyncThreadDiscrete.getConf()).thenReturn(mockAsyncConfiguration);
+        when(asyncThreadDiscrete.getConfiguration()).thenReturn(mockAsyncConfiguration);
         when(mockAsyncConfiguration.getRewardFactor()).thenReturn(1.0);
         when(asyncThreadDiscrete.getAsyncGlobal()).thenReturn(mockAsyncGlobal);
         when(asyncThreadDiscrete.getPolicy(eq(mockGlobalTargetNetwork))).thenReturn(mockGlobalCurrentPolicy);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
index 117465de3..5b6afbc28 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java
@@ -39,7 +39,6 @@ import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.clearInvocations;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -130,7 +129,7 @@ public class AsyncThreadTest {
 
         when(mockAsyncConfiguration.getMaxEpochStep()).thenReturn(maxStepsPerEpisode);
         when(mockAsyncConfiguration.getNStep()).thenReturn(nstep);
-        when(thread.getConf()).thenReturn(mockAsyncConfiguration);
+        when(thread.getConfiguration()).thenReturn(mockAsyncConfiguration);
 
         // if we hit the max step count
         when(mockAsyncGlobal.isTrainingComplete()).thenAnswer(invocation -> thread.getStepCount() >= maxSteps);
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
index 82129e0df..e19af338b 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java
@@ -18,24 +18,16 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;
 
 import org.deeplearning4j.gym.StepReply;
-import org.deeplearning4j.rl4j.experience.ExperienceHandler;
-import org.deeplearning4j.rl4j.experience.StateActionPair;
 import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
 import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
-import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
-import org.deeplearning4j.rl4j.learning.sync.Transition;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
 import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
-import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
-import org.deeplearning4j.rl4j.support.*;
-import org.deeplearning4j.rl4j.util.DataManagerTrainingListener;
-import org.deeplearning4j.rl4j.util.IDataManager;
-import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -43,17 +35,17 @@ import org.mockito.Mock;
 import org.mockito.Mockito;
 import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.api.DataSet;
-import org.nd4j.linalg.api.rng.Random;
 import org.nd4j.linalg.factory.Nd4j;
 
-import java.util.ArrayList;
-import java.util.List;
-
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 @RunWith(MockitoJUnitRunner.class)
@@ -82,6 +74,7 @@ public class QLearningDiscreteTest {
     @Mock
     QLearningConfiguration mockQlearningConfiguration;
 
+    // CHW: channels, then height and width (see the cropping stubs below)
     int[] observationShape = new int[]{3, 10, 10};
     int totalObservationSize = 1;
 
@@ -123,6 +116,7 @@ public class QLearningDiscreteTest {
         when(mockHistoryConfiguration.getCroppingHeight()).thenReturn(observationShape[1]);
         when(mockHistoryConfiguration.getCroppingWidth()).thenReturn(observationShape[2]);
         when(mockHistoryConfiguration.getSkipFrame()).thenReturn(skipFrames);
+        when(mockHistoryConfiguration.getHistoryLength()).thenReturn(1);
         when(mockHistoryProcessor.getConf()).thenReturn(mockHistoryConfiguration);
 
         qLearningDiscrete.setHistoryProcessor(mockHistoryProcessor);
@@ -148,7 +142,7 @@ public class QLearningDiscreteTest {
         Observation observation = new Observation(Nd4j.zeros(observationShape));
         when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
 
-        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Observation(Nd4j.zeros(observationShape)), 0, false, null));
 
         // Act
         QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
@@ -170,25 +164,26 @@ public class QLearningDiscreteTest {
         // Arrange
         mockTestContext(100,0,2,1.0, 10);
 
-        mockHistoryProcessor(2);
+        Observation skippedObservation = Observation.SkippedObservation;
+        Observation nextObservation = new Observation(Nd4j.zeros(observationShape));
 
-        // An example observation and 2 Q values output (2 actions)
-        Observation observation = new Observation(Nd4j.zeros(observationShape));
-        when(mockDQN.output(eq(observation))).thenReturn(Nd4j.create(new float[] {1.0f, 0.5f}));
-
-        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(new Box(new double[totalObservationSize]), 0, false, null));
+        when(mockMDP.step(anyInt())).thenReturn(new StepReply<>(nextObservation, 0, false, null));
 
         // Act
-        QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(observation);
+        QLearning.QLStepReturn<Observation> stepReturn = qLearningDiscrete.trainStep(skippedObservation);
 
         // Assert
-        assertEquals(1.0, stepReturn.getMaxQ(), 1e-5);
+        assertEquals(Double.NaN, stepReturn.getMaxQ(), 1e-5);
 
         StepReply<Observation> stepReply = stepReturn.getStepReply();
 
         assertEquals(0, stepReply.getReward(), 1e-5);
         assertFalse(stepReply.isDone());
-        assertTrue(stepReply.getObservation().isSkipped());
+        assertFalse(stepReply.getObservation().isSkipped());
+        assertEquals(0, qLearningDiscrete.getExperienceHandler().getTrainingBatchSize());
+
+        verify(mockDQN, never()).output(any(INDArray.class));
+
     }
 
     //TODO: there are many more test cases here that can be improved upon
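
One consequence worth spelling out: NaN now acts as the "no Q-value was computed" sentinel for skipped observations, the DQN is never queried and nothing reaches the experience handler, while the environment is still stepped. Callers that aggregate maxQ therefore presumably need a guard; a hypothetical caller-side sketch (obs and totalMaxQ are illustrative, trainStep/getMaxQ as used in the test above):

    QLearning.QLStepReturn<Observation> ret = qLearningDiscrete.trainStep(obs);
    if (!Double.isNaN(ret.getMaxQ())) {
        totalMaxQ += ret.getMaxQ(); // only count steps that actually evaluated the DQN
    }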
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java
index 9c7a172bb..9126ea1fa 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java
@@ -17,7 +17,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .isFirstDimenstionBatch(false)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -35,7 +35,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .isFirstDimenstionBatch(true)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 }).reshape(1, 3);
 
         // Act
@@ -53,7 +53,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .isFirstDimenstionBatch(true)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -70,7 +70,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .shouldStoreCopy(false)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -87,7 +87,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .shouldStoreCopy(true)
                 .elementStore(store)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
@@ -107,7 +107,7 @@ public class HistoryMergeTransformTest {
         HistoryMergeTransform sut = HistoryMergeTransform.builder()
                 .elementStore(store)
                 .assembler(assemble)
-                .build();
+                .build(4);
         INDArray input = Nd4j.create(new double[] { 1.0, 2.0, 3.0 });
 
         // Act
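
Every construction site now passes the history length into build(int) instead of leaving the transform to infer it; presumably it sizes the merge/stacking logic. A construction sketch consistent with these tests (the meaning of the argument is inferred; CircularFifoStore is the element store used in MockMDP further down):

    HistoryMergeTransform transform = HistoryMergeTransform.builder()
            .isFirstDimenstionBatch(true)           // builder method spelling as in the tests
            .elementStore(new CircularFifoStore(4)) // keeps the last 4 elements
            .build(4);                              // 4 = frame-history length (inferred)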
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
index 7db92a599..249304afb 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java
@@ -30,13 +30,8 @@ import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.network.ac.IActorCritic;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.space.ActionSpace;
-import org.deeplearning4j.rl4j.support.MockDQN;
-import org.deeplearning4j.rl4j.support.MockEncodable;
-import org.deeplearning4j.rl4j.support.MockHistoryProcessor;
-import org.deeplearning4j.rl4j.support.MockMDP;
-import org.deeplearning4j.rl4j.support.MockNeuralNet;
-import org.deeplearning4j.rl4j.support.MockObservationSpace;
-import org.deeplearning4j.rl4j.support.MockRandom;
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.deeplearning4j.rl4j.support.*;
 import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
@@ -227,7 +222,7 @@ public class PolicyTest {
         assertEquals(0, dqn.outputParams.size());
     }
 
-    public static class MockRefacPolicy extends Policy<MockEncodable, Integer> {
+    public static class MockRefacPolicy extends Policy<MockObservation, Integer> {
 
         private NeuralNet neuralNet;
         private final int[] shape;
@@ -257,8 +252,8 @@ public class PolicyTest {
         }
 
         @Override
-        protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockEncodable, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
-            mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(shape, skipFrame, historyLength));
+        protected <AS extends ActionSpace<Integer>> Learning.InitMdp<Observation> refacInitMdp(LegacyMDPWrapper<MockObservation, Integer, AS> mdpWrapper, IHistoryProcessor hp) {
+            mdpWrapper.setTransformProcess(MockMDP.buildTransformProcess(skipFrame, historyLength));
             return super.refacInitMdp(mdpWrapper, hp);
         }
     }
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java
deleted file mode 100644
index 436205b42..000000000
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockEncodable.java
+++ /dev/null
@@ -1,18 +0,0 @@
-package org.deeplearning4j.rl4j.support;
-
-import org.deeplearning4j.rl4j.space.Encodable;
-
-public class MockEncodable implements Encodable {
-
-    private final int value;
-
-    public MockEncodable(int value) {
-
-        this.value = value;
-    }
-
-    @Override
-    public double[] toArray() {
-        return new double[] { value };
-    }
-}
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java
index bbed87624..61db6dd6e 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java
@@ -2,9 +2,10 @@ package org.deeplearning4j.rl4j.support;
 
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
 import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
 import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
-import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToINDArrayTransform;
+import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
 import org.deeplearning4j.rl4j.observation.transform.operation.historymerge.CircularFifoStore;
@@ -15,7 +16,7 @@ import org.nd4j.linalg.api.rng.Random;
 import java.util.ArrayList;
 import java.util.List;
 
-public class MockMDP implements MDP<MockEncodable, Integer, DiscreteSpace> {
+public class MockMDP implements MDP<MockObservation, Integer, DiscreteSpace> {
 
     private final DiscreteSpace actionSpace;
     private final int stepsUntilDone;
@@ -55,11 +56,11 @@ public class MockMDP implements MDP {
     }
 
     @Override
-    public MockEncodable reset() {
+    public MockObservation reset() {
         ++resetCount;
         currentObsValue = 0;
         step = 0;
-        return new MockEncodable(currentObsValue++);
+        return new MockObservation(currentObsValue++);
     }
 
     @Override
@@ -68,10 +69,10 @@ public class MockMDP implements MDP {
     }
 
     @Override
-    public StepReply<MockEncodable> step(Integer action) {
+    public StepReply<MockObservation> step(Integer action) {
         actions.add(action);
         ++step;
-        return new StepReply<>(new MockEncodable(currentObsValue), (double) currentObsValue++, isDone(), null);
+        return new StepReply<>(new MockObservation(currentObsValue), (double) currentObsValue++, isDone(), null);
     }
 
     @Override
@@ -84,14 +85,14 @@ public class MockMDP implements MDP {
         return null;
     }
 
-    public static TransformProcess buildTransformProcess(int[] shape, int skipFrame, int historyLength) {
+    public static TransformProcess buildTransformProcess(int skipFrame, int historyLength) {
         return TransformProcess.builder()
                 .filter(new UniformSkippingFilter(skipFrame))
-                .transform("data", new EncodableToINDArrayTransform(shape))
+                .transform("data", new EncodableToINDArrayTransform())
                 .transform("data", new SimpleNormalizationTransform(0.0, 255.0))
                 .transform("data", HistoryMergeTransform.builder()
                         .elementStore(new CircularFifoStore(historyLength))
-                        .build())
+                        .build(4))
                 .build("data");
     }
 
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java
new file mode 100644
index 000000000..70a3e76c6
--- /dev/null
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java
@@ -0,0 +1,51 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K. K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.rl4j.support;
+
+
+import org.deeplearning4j.rl4j.space.Encodable;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+public class MockObservation implements Encodable {
+
+    final INDArray data;
+
+    public MockObservation(int value) {
+        this.data = Nd4j.ones(1).mul(value);
+    }
+
+    @Override
+    public double[] toArray() {
+        return data.data().asDouble();
+    }
+
+    @Override
+    public boolean isSkipped() {
+        return false;
+    }
+
+    @Override
+    public INDArray getData() {
+        return data;
+    }
+
+    @Override
+    public Encodable dup() {
+        return null;
+    }
+}
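
Besides serving the tests, this mock documents the widened Encodable contract the PR relies on: alongside the legacy toArray(), implementations now expose an INDArray view plus skip/dup hooks (signatures assumed to mirror the real interface, as implemented above):

    Encodable obs = new MockObservation(7);
    double[] flat = obs.toArray();     // legacy double[] view: { 7.0 }
    INDArray nd = obs.getData();       // INDArray-backed view, shape [1]
    boolean skipped = obs.isSkipped(); // always false for this mock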
diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
index 4c4f100e9..8786f7d7d 100644
--- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
+++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java
@@ -5,18 +5,19 @@ import org.deeplearning4j.rl4j.mdp.MDP;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.deeplearning4j.rl4j.policy.IPolicy;
 import org.deeplearning4j.rl4j.space.ActionSpace;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
 import java.util.ArrayList;
 import java.util.List;
 
-public class MockPolicy implements IPolicy<MockEncodable, Integer> {
+public class MockPolicy implements IPolicy<Encodable, Integer> {
 
     public int playCallCount = 0;
     public List<INDArray> actionInputs = new ArrayList<INDArray>();
 
     @Override
-    public <AS extends ActionSpace<Integer>> double play(MDP<MockEncodable, Integer, AS> mdp, IHistoryProcessor hp) {
+    public <O extends Encodable, AS extends ActionSpace<Integer>> double play(MDP<O, Integer, AS> mdp, IHistoryProcessor hp) {
         ++playCallCount;
         return 0;
     }
@@ -31,4 +32,9 @@ public class MockPolicy implements IPolicy {
     public Integer nextAction(Observation observation) {
         return nextAction(observation.getData());
     }
+
+    @Override
+    public void reset() {
+
+    }
 }
diff --git a/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java b/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java
index 4c331ad78..becce416f 100644
--- a/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java
+++ b/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java
@@ -28,6 +28,9 @@ import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
 import vizdoom.*;
 
 import java.util.ArrayList;
@@ -155,7 +158,7 @@ abstract public class VizDoom implements MDP<VizDoom.GameScreen, Integer, DiscreteSpace> {
diff --git a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java
--- a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java
+++ b/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java
-public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
+public class GymEnv<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<OBSERVATION, A, AS> {
 
     public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn";
 
@@ -82,7 +80,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     private PyObject locals;
 
     final protected DiscreteSpace actionSpace;
-    final protected ObservationSpace<O> observationSpace;
+    final protected ObservationSpace<OBSERVATION> observationSpace;
     @Getter
     final private String envId;
     @Getter
@@ -119,7 +117,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
             for (int i = 0; i < shape.length; i++) {
                 shape[i] = (int)PyLong_AsLong(PyTuple_GetItem(shapeTuple, i));
             }
-            observationSpace = (ObservationSpace<O>) new ArrayObservationSpace(shape);
+            observationSpace = (ObservationSpace<OBSERVATION>) new ArrayObservationSpace(shape);
             Py_DecRef(shapeTuple);
 
             PyObject n = PyRun_StringFlags("env.action_space.n", Py_eval_input, globals, locals, null);
@@ -140,7 +138,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     }
 
     @Override
-    public ObservationSpace<O> getObservationSpace() {
+    public ObservationSpace<OBSERVATION> getObservationSpace() {
         return observationSpace;
     }
 
@@ -153,7 +151,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     }
 
     @Override
-    public StepReply<O> step(A action) {
+    public StepReply<OBSERVATION> step(A action) {
         int gstate = PyGILState_Ensure();
         try {
             if (render) {
@@ -186,7 +184,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     }
 
     @Override
-    public O reset() {
+    public OBSERVATION reset() {
         int gstate = PyGILState_Ensure();
         try {
             Py_DecRef(PyRun_StringFlags("state = env.reset()", Py_single_input, globals, locals, null));
@@ -201,7 +199,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
 
             double[] data = new double[(int)stateData.capacity()];
             stateData.get(data);
-            return (O) new Box(data);
+            return (OBSERVATION) new Box(data);
         } finally {
             PyGILState_Release(gstate);
         }
@@ -220,7 +218,7 @@ public class GymEnv<O, A, AS extends ActionSpace<A>> implements MDP<O, A, AS> {
     }
 
     @Override
-    public GymEnv<O, A, AS> newInstance() {
-        return new GymEnv<O, A, AS>(envId, render, monitor);
+    public GymEnv<OBSERVATION, A, AS> newInstance() {
+        return new GymEnv<OBSERVATION, A, AS>(envId, render, monitor);
     }
 }
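
The rename is mechanical (O becomes OBSERVATION), so call sites keep working unchanged. For illustration, a hypothetical instantiation with Box as the observation type (constructor arguments as in newInstance() above; the env id is illustrative):

    GymEnv<Box, Integer, DiscreteSpace> env = new GymEnv<>("CartPole-v0", false, false);
    Box first = env.reset();            // OBSERVATION = Box here
    StepReply<Box> reply = env.step(0); // take action 0
    env.close();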
diff --git a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java
index 2196d7b31..4faf26b2b 100644
--- a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java
+++ b/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java
@@ -40,8 +40,8 @@ public class GymEnvTest {
         assertEquals(false, mdp.isDone());
         Box o = (Box)mdp.reset();
         StepReply r = mdp.step(0);
-        assertEquals(4, o.toArray().length);
-        assertEquals(4, ((Box)r.getObservation()).toArray().length);
+        assertEquals(4, o.getData().shape()[0]);
+        assertEquals(4, ((Box)r.getObservation()).getData().shape()[0]);
         assertNotEquals(null, mdp.newInstance());
         mdp.close();
     }
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java
index 412976b27..91cec3d8b 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K. K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -16,33 +16,13 @@
 
 package org.deeplearning4j.malmo;
 
-import java.util.Arrays;
+import org.deeplearning4j.rl4j.space.Box;
 
-import org.deeplearning4j.rl4j.space.Encodable;
+@Deprecated
+public class MalmoBox extends Box {
 
-/**
- * Encodable state as a simple value array similar to Gym Box model, but without a JSON constructor
- * @author howard-abrams (howard.abrams@ca.com) on 1/12/17.
- */
-public class MalmoBox implements Encodable {
-    double[] value;
-
-    /**
-     * Construct state from an array of doubles
-     * @param value state values
-     */
-    //TODO: If this constructor was added to "Box", we wouldn't need this class at all.
-    public MalmoBox(double... value) {
-        this.value = value;
-    }
-
-    @Override
-    public double[] toArray() {
-        return value;
-    }
-
-    @Override
-    public String toString() {
-        return Arrays.toString(value);
+    public MalmoBox(double... arr) {
+        super(arr);
     }
 }
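
With MalmoBox reduced to a deprecated shim, the Box base class carries the payload as an INDArray (its getData() accessor is visible in the GymEnvTest hunk above). Roughly:

    MalmoBox observation = new MalmoBox(0.0, 1.0, 2.0);
    INDArray data = observation.getData();   // shape [3]: 0.0, 1.0, 2.0
    double[] legacy = observation.toArray(); // Encodable view is still available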
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java
index c853fa362..d68de87ef 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java
@@ -19,6 +19,8 @@ package org.deeplearning4j.malmo;
 import java.util.Arrays;
 
 import com.microsoft.msr.malmo.WorldState;
+import org.deeplearning4j.rl4j.space.Box;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
 /**
 * A Malmo consistency policy that ensures both that there is a reward and that the next observation has a different position than the previous one.
@@ -30,14 +32,14 @@ public class MalmoDescretePositionPolicy implements MalmoObservationPolicy {
 
     @Override
     public boolean isObservationConsistant(WorldState world_state, WorldState original_world_state) {
-        MalmoBox last_observation = observationSpace.getObservation(world_state);
-        MalmoBox old_observation = observationSpace.getObservation(original_world_state);
+        Box last_observation = observationSpace.getObservation(world_state);
+        Box old_observation = observationSpace.getObservation(original_world_state);
 
-        double[] newvalues = old_observation == null ? null : old_observation.toArray();
-        double[] oldvalues = last_observation == null ? null : last_observation.toArray();
+        INDArray newvalues = last_observation == null ? null : last_observation.getData();
+        INDArray oldvalues = old_observation == null ? null : old_observation.getData();
 
         return !(world_state.getObservations().isEmpty() || world_state.getRewards().isEmpty()
-                        || Arrays.equals(oldvalues, newvalues));
+                        || (oldvalues == null || newvalues == null ? oldvalues == newvalues : oldvalues.eq(newvalues).all()));
     }
 
 }
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java
index b27412b99..b98db3650 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java
@@ -21,6 +21,7 @@ import java.nio.file.Paths;
 
 import org.deeplearning4j.gym.StepReply;
 import org.deeplearning4j.rl4j.mdp.MDP;
+import org.deeplearning4j.rl4j.space.Box;
 import org.deeplearning4j.rl4j.space.DiscreteSpace;
 
 import com.microsoft.msr.malmo.AgentHost;
@@ -34,6 +35,7 @@ import com.microsoft.msr.malmo.WorldState;
 import lombok.Setter;
 import lombok.Getter;
 
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -233,7 +235,7 @@ public class MalmoEnv implements MDP<MalmoBox, Integer, DiscreteSpace> {
             logger.info("Mission ended");
         }
 
-        return new StepReply(last_observation, getRewards(last_world_state), isDone(), null);
+        return new StepReply<>(last_observation, getRewards(last_world_state), isDone(), null);
     }
 
     private double getRewards(WorldState world_state) {
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java
index 61a0dddc7..cc140bee2 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java
@@ -16,6 +16,8 @@
 
 package org.deeplearning4j.malmo;
 
+import org.deeplearning4j.rl4j.space.Box;
+import org.deeplearning4j.rl4j.space.Encodable;
 import org.deeplearning4j.rl4j.space.ObservationSpace;
 
 import com.microsoft.msr.malmo.WorldState;
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java
index 00b7c4f7a..1595def55 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java
@@ -19,6 +19,7 @@ package org.deeplearning4j.malmo;
 
 import com.microsoft.msr.malmo.TimestampedStringVector;
 import com.microsoft.msr.malmo.WorldState;
+import org.deeplearning4j.rl4j.space.Box;
 import org.json.JSONArray;
 import org.json.JSONObject;
 import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java
index 52dc02918..4fbbb6cc2 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.malmo;
 
 import java.util.HashMap;
 
+import org.deeplearning4j.rl4j.space.Box;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java
index 50f710bf5..cf85059d8 100644
--- a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java
+++ b/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java
@@ -16,6 +16,7 @@
 
 package org.deeplearning4j.malmo;
 
+import org.deeplearning4j.rl4j.space.Box;
 import org.json.JSONObject;
 
 import org.nd4j.linalg.api.ndarray.INDArray;